Home | History | Annotate | Download | only in sksl
      1 /*
      2  * Copyright 2016 Google Inc.
      3  *
      4  * Use of this source code is governed by a BSD-style license that can be
      5  * found in the LICENSE file.
      6  */
      7 
      8 #include "SkSLMetalCodeGenerator.h"
      9 
     10 #include "SkSLCompiler.h"
     11 #include "ir/SkSLExpressionStatement.h"
     12 #include "ir/SkSLExtension.h"
     13 #include "ir/SkSLIndexExpression.h"
     14 #include "ir/SkSLModifiersDeclaration.h"
     15 #include "ir/SkSLNop.h"
     16 #include "ir/SkSLVariableReference.h"
     17 
     18 namespace SkSL {
     19 
     20 void MetalCodeGenerator::write(const char* s) {
     21     if (!s[0]) {
     22         return;
     23     }
     24     if (fAtLineStart) {
     25         for (int i = 0; i < fIndentation; i++) {
     26             fOut->writeText("    ");
     27         }
     28     }
     29     fOut->writeText(s);
     30     fAtLineStart = false;
     31 }
     32 
     33 void MetalCodeGenerator::writeLine(const char* s) {
     34     this->write(s);
     35     fOut->writeText(fLineEnding);
     36     fAtLineStart = true;
     37 }
     38 
     39 void MetalCodeGenerator::write(const String& s) {
     40     this->write(s.c_str());
     41 }
     42 
     43 void MetalCodeGenerator::writeLine(const String& s) {
     44     this->writeLine(s.c_str());
     45 }
     46 
     47 void MetalCodeGenerator::writeLine() {
     48     this->writeLine("");
     49 }
     50 
     51 void MetalCodeGenerator::writeExtension(const Extension& ext) {
     52     this->writeLine("#extension " + ext.fName + " : enable");
     53 }
     54 
     55 void MetalCodeGenerator::writeType(const Type& type) {
     56     switch (type.kind()) {
     57         case Type::kStruct_Kind:
     58             for (const Type* search : fWrittenStructs) {
     59                 if (*search == type) {
     60                     // already written
     61                     this->write(type.name());
     62                     return;
     63                 }
     64             }
     65             fWrittenStructs.push_back(&type);
     66             this->writeLine("struct " + type.name() + " {");
     67             fIndentation++;
     68             for (const auto& f : type.fields()) {
     69                 this->writeModifiers(f.fModifiers, false);
     70                 // sizes (which must be static in structs) are part of the type name here
     71                 this->writeType(*f.fType);
     72                 this->writeLine(" " + f.fName + ";");
     73             }
     74             fIndentation--;
     75             this->write("}");
     76             break;
     77         case Type::kVector_Kind:
     78             this->writeType(type.componentType());
     79             this->write(to_string(type.columns()));
     80             break;
     81         default:
     82             this->write(type.name());
     83     }
     84 }
     85 
     86 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) {
     87     switch (expr.fKind) {
     88         case Expression::kBinary_Kind:
     89             this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence);
     90             break;
     91         case Expression::kBoolLiteral_Kind:
     92             this->writeBoolLiteral((BoolLiteral&) expr);
     93             break;
     94         case Expression::kConstructor_Kind:
     95             this->writeConstructor((Constructor&) expr);
     96             break;
     97         case Expression::kIntLiteral_Kind:
     98             this->writeIntLiteral((IntLiteral&) expr);
     99             break;
    100         case Expression::kFieldAccess_Kind:
    101             this->writeFieldAccess(((FieldAccess&) expr));
    102             break;
    103         case Expression::kFloatLiteral_Kind:
    104             this->writeFloatLiteral(((FloatLiteral&) expr));
    105             break;
    106         case Expression::kFunctionCall_Kind:
    107             this->writeFunctionCall((FunctionCall&) expr);
    108             break;
    109         case Expression::kPrefix_Kind:
    110             this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence);
    111             break;
    112         case Expression::kPostfix_Kind:
    113             this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence);
    114             break;
    115         case Expression::kSetting_Kind:
    116             this->writeSetting((Setting&) expr);
    117             break;
    118         case Expression::kSwizzle_Kind:
    119             this->writeSwizzle((Swizzle&) expr);
    120             break;
    121         case Expression::kVariableReference_Kind:
    122             this->writeVariableReference((VariableReference&) expr);
    123             break;
    124         case Expression::kTernary_Kind:
    125             this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence);
    126             break;
    127         case Expression::kIndex_Kind:
    128             this->writeIndexExpression((IndexExpression&) expr);
    129             break;
    130         default:
    131             ABORT("unsupported expression: %s", expr.description().c_str());
    132     }
    133 }
    134 
    135 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) {
    136     if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) {
    137         this->write("atan2");
    138     } else {
    139         this->write(c.fFunction.fName);
    140     }
    141     this->write("(");
    142     const char* separator = "";
    143     if (this->requirements(c.fFunction) & kInputs_Requirement) {
    144         this->write("_in");
    145         separator = ", ";
    146     }
    147     if (this->requirements(c.fFunction) & kOutputs_Requirement) {
    148         this->write(separator);
    149         this->write("_out");
    150         separator = ", ";
    151     }
    152     if (this->requirements(c.fFunction) & kUniforms_Requirement) {
    153         this->write(separator);
    154         this->write("_uniforms");
    155         separator = ", ";
    156     }
    157     for (size_t i = 0; i < c.fArguments.size(); ++i) {
    158         const Expression& arg = *c.fArguments[i];
    159         this->write(separator);
    160         separator = ", ";
    161         if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) {
    162             this->write("&");
    163         }
    164         this->writeExpression(arg, kSequence_Precedence);
    165     }
    166     this->write(")");
    167 }
    168 
    169 void MetalCodeGenerator::writeConstructor(const Constructor& c) {
    170     this->writeType(c.fType);
    171     this->write("(");
    172     const char* separator = "";
    173     int scalarCount = 0;
    174     for (const auto& arg : c.fArguments) {
    175         this->write(separator);
    176         separator = ", ";
    177         if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) {
    178             // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge to
    179             // float2x2(float2, float2).
    180             if (!scalarCount) {
    181                 this->writeType(c.fType.componentType());
    182                 this->write(to_string(c.fType.rows()));
    183                 this->write("(");
    184             }
    185             ++scalarCount;
    186         }
    187         this->writeExpression(*arg, kSequence_Precedence);
    188         if (scalarCount && scalarCount == c.fType.rows()) {
    189             this->write(")");
    190             scalarCount = 0;
    191         }
    192     }
    193     this->write(")");
    194 }
    195 
    196 void MetalCodeGenerator::writeFragCoord() {
    197     this->write("_in.position");
    198 }
    199 
    200 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) {
    201     switch (ref.fVariable.fModifiers.fLayout.fBuiltin) {
    202         case SK_FRAGCOLOR_BUILTIN:
    203             this->write("sk_FragColor");
    204             break;
    205         default:
    206             if (Variable::kGlobal_Storage == ref.fVariable.fStorage) {
    207                 if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
    208                     this->write("_in.");
    209                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
    210                     this->write("_out.");
    211                 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
    212                     this->write("_uniforms.");
    213                 } else {
    214                     fErrors.error(ref.fVariable.fOffset, "Metal backend does not support global "
    215                                   "variables");
    216                 }
    217             }
    218             this->write(ref.fVariable.fName);
    219     }
    220 }
    221 
    222 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) {
    223     this->writeExpression(*expr.fBase, kPostfix_Precedence);
    224     this->write("[");
    225     this->writeExpression(*expr.fIndex, kTopLevel_Precedence);
    226     this->write("]");
    227 }
    228 
    229 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) {
    230     if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) {
    231         this->writeExpression(*f.fBase, kPostfix_Precedence);
    232         this->write(".");
    233     }
    234     switch (f.fBase->fType.fields()[f.fFieldIndex].fModifiers.fLayout.fBuiltin) {
    235         case SK_CLIPDISTANCE_BUILTIN:
    236             this->write("gl_ClipDistance");
    237             break;
    238         case SK_POSITION_BUILTIN:
    239             this->write("_out.position");
    240             break;
    241         default:
    242             this->write(f.fBase->fType.fields()[f.fFieldIndex].fName);
    243     }
    244 }
    245 
    246 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) {
    247     this->writeExpression(*swizzle.fBase, kPostfix_Precedence);
    248     this->write(".");
    249     for (int c : swizzle.fComponents) {
    250         this->write(&("x\0y\0z\0w\0"[c * 2]));
    251     }
    252 }
    253 
    254 MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) {
    255     switch (op) {
    256         case Token::STAR:         // fall through
    257         case Token::SLASH:        // fall through
    258         case Token::PERCENT:      return MetalCodeGenerator::kMultiplicative_Precedence;
    259         case Token::PLUS:         // fall through
    260         case Token::MINUS:        return MetalCodeGenerator::kAdditive_Precedence;
    261         case Token::SHL:          // fall through
    262         case Token::SHR:          return MetalCodeGenerator::kShift_Precedence;
    263         case Token::LT:           // fall through
    264         case Token::GT:           // fall through
    265         case Token::LTEQ:         // fall through
    266         case Token::GTEQ:         return MetalCodeGenerator::kRelational_Precedence;
    267         case Token::EQEQ:         // fall through
    268         case Token::NEQ:          return MetalCodeGenerator::kEquality_Precedence;
    269         case Token::BITWISEAND:   return MetalCodeGenerator::kBitwiseAnd_Precedence;
    270         case Token::BITWISEXOR:   return MetalCodeGenerator::kBitwiseXor_Precedence;
    271         case Token::BITWISEOR:    return MetalCodeGenerator::kBitwiseOr_Precedence;
    272         case Token::LOGICALAND:   return MetalCodeGenerator::kLogicalAnd_Precedence;
    273         case Token::LOGICALXOR:   return MetalCodeGenerator::kLogicalXor_Precedence;
    274         case Token::LOGICALOR:    return MetalCodeGenerator::kLogicalOr_Precedence;
    275         case Token::EQ:           // fall through
    276         case Token::PLUSEQ:       // fall through
    277         case Token::MINUSEQ:      // fall through
    278         case Token::STAREQ:       // fall through
    279         case Token::SLASHEQ:      // fall through
    280         case Token::PERCENTEQ:    // fall through
    281         case Token::SHLEQ:        // fall through
    282         case Token::SHREQ:        // fall through
    283         case Token::LOGICALANDEQ: // fall through
    284         case Token::LOGICALXOREQ: // fall through
    285         case Token::LOGICALOREQ:  // fall through
    286         case Token::BITWISEANDEQ: // fall through
    287         case Token::BITWISEXOREQ: // fall through
    288         case Token::BITWISEOREQ:  return MetalCodeGenerator::kAssignment_Precedence;
    289         case Token::COMMA:        return MetalCodeGenerator::kSequence_Precedence;
    290         default: ABORT("unsupported binary operator");
    291     }
    292 }
    293 
    294 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b,
    295                                                Precedence parentPrecedence) {
    296     Precedence precedence = GetBinaryPrecedence(b.fOperator);
    297     if (precedence >= parentPrecedence) {
    298         this->write("(");
    299     }
    300     if (Compiler::IsAssignment(b.fOperator) &&
    301         Expression::kVariableReference_Kind == b.fLeft->fKind &&
    302         Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage &&
    303         (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
    304         // writing to an out parameter. Since we have to turn those into pointers, we have to
    305         // dereference it here.
    306         this->write("*");
    307     }
    308     this->writeExpression(*b.fLeft, precedence);
    309     if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) &&
    310         Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) {
    311         // This doesn't compile in Metal:
    312         // float4 x = float4(1);
    313         // x.xy *= float2x2(...);
    314         // with the error message "non-const reference cannot bind to vector element",
    315         // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation
    316         // as long as the LHS has no side effects, and hope for the best otherwise.
    317         this->write(" = ");
    318         this->writeExpression(*b.fLeft, kAssignment_Precedence);
    319         this->write(" ");
    320         String op = Compiler::OperatorName(b.fOperator);
    321         ASSERT(op.endsWith("="));
    322         this->write(op.substr(0, op.size() - 1).c_str());
    323         this->write(" ");
    324     } else {
    325         this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " ");
    326     }
    327     this->writeExpression(*b.fRight, precedence);
    328     if (precedence >= parentPrecedence) {
    329         this->write(")");
    330     }
    331 }
    332 
    333 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t,
    334                                                Precedence parentPrecedence) {
    335     if (kTernary_Precedence >= parentPrecedence) {
    336         this->write("(");
    337     }
    338     this->writeExpression(*t.fTest, kTernary_Precedence);
    339     this->write(" ? ");
    340     this->writeExpression(*t.fIfTrue, kTernary_Precedence);
    341     this->write(" : ");
    342     this->writeExpression(*t.fIfFalse, kTernary_Precedence);
    343     if (kTernary_Precedence >= parentPrecedence) {
    344         this->write(")");
    345     }
    346 }
    347 
    348 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p,
    349                                               Precedence parentPrecedence) {
    350     if (kPrefix_Precedence >= parentPrecedence) {
    351         this->write("(");
    352     }
    353     this->write(Compiler::OperatorName(p.fOperator));
    354     this->writeExpression(*p.fOperand, kPrefix_Precedence);
    355     if (kPrefix_Precedence >= parentPrecedence) {
    356         this->write(")");
    357     }
    358 }
    359 
    360 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p,
    361                                                Precedence parentPrecedence) {
    362     if (kPostfix_Precedence >= parentPrecedence) {
    363         this->write("(");
    364     }
    365     this->writeExpression(*p.fOperand, kPostfix_Precedence);
    366     this->write(Compiler::OperatorName(p.fOperator));
    367     if (kPostfix_Precedence >= parentPrecedence) {
    368         this->write(")");
    369     }
    370 }
    371 
    372 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
    373     this->write(b.fValue ? "true" : "false");
    374 }
    375 
    376 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) {
    377     if (i.fType == *fContext.fUInt_Type) {
    378         this->write(to_string(i.fValue & 0xffffffff) + "u");
    379     } else {
    380         this->write(to_string((int32_t) i.fValue));
    381     }
    382 }
    383 
    384 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
    385     this->write(to_string(f.fValue));
    386 }
    387 
    388 void MetalCodeGenerator::writeSetting(const Setting& s) {
    389     ABORT("internal error; setting was not folded to a constant during compilation\n");
    390 }
    391 
    392 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) {
    393     const char* separator = "";
    394     if ("main" == f.fDeclaration.fName) {
    395         switch (fProgram.fKind) {
    396             case Program::kFragment_Kind:
    397                 this->write("fragment half4 _frag");
    398                 break;
    399             case Program::kVertex_Kind:
    400                 this->write("vertex Outputs _vert");
    401                 break;
    402             default:
    403                 ASSERT(false);
    404         }
    405         this->write("(Inputs _in [[stage_in]]");
    406         if (-1 != fUniformBuffer) {
    407             this->write(", constant Uniforms& _uniforms [[buffer(" +
    408                         to_string(fUniformBuffer) + ")]]");
    409         }
    410         separator = ", ";
    411     } else {
    412         this->writeType(f.fDeclaration.fReturnType);
    413         this->write(" " + f.fDeclaration.fName + "(");
    414         if (this->requirements(f.fDeclaration) & kInputs_Requirement) {
    415             this->write("Inputs _in");
    416             separator = ", ";
    417         }
    418         if (this->requirements(f.fDeclaration) & kOutputs_Requirement) {
    419             this->write(separator);
    420             this->write("thread Outputs& _out");
    421             separator = ", ";
    422         }
    423         if (this->requirements(f.fDeclaration) & kUniforms_Requirement) {
    424             this->write(separator);
    425             this->write("Uniforms _uniforms");
    426             separator = ", ";
    427         }
    428     }
    429     for (const auto& param : f.fDeclaration.fParameters) {
    430         this->write(separator);
    431         separator = ", ";
    432         this->writeModifiers(param->fModifiers, false);
    433         std::vector<int> sizes;
    434         const Type* type = &param->fType;
    435         while (Type::kArray_Kind == type->kind()) {
    436             sizes.push_back(type->columns());
    437             type = &type->componentType();
    438         }
    439         this->writeType(*type);
    440         if (param->fModifiers.fFlags & Modifiers::kOut_Flag) {
    441             this->write("*");
    442         }
    443         this->write(" " + param->fName);
    444         for (int s : sizes) {
    445             if (s <= 0) {
    446                 this->write("[]");
    447             } else {
    448                 this->write("[" + to_string(s) + "]");
    449             }
    450         }
    451     }
    452     this->writeLine(") {");
    453 
    454     ASSERT(!fProgram.fSettings.fFragColorIsInOut);
    455 
    456     if ("main" == f.fDeclaration.fName) {
    457         switch (fProgram.fKind) {
    458             case Program::kFragment_Kind:
    459                 this->writeLine("    half4 sk_FragColor;");
    460                 break;
    461             case Program::kVertex_Kind:
    462                 this->writeLine("    Outputs _out;");
    463                 break;
    464             default:
    465                 ASSERT(false);
    466         }
    467     }
    468     fFunctionHeader = "";
    469     OutputStream* oldOut = fOut;
    470     StringStream buffer;
    471     fOut = &buffer;
    472     fIndentation++;
    473     this->writeStatements(((Block&) *f.fBody).fStatements);
    474     if ("main" == f.fDeclaration.fName) {
    475         switch (fProgram.fKind) {
    476             case Program::kFragment_Kind:
    477                 this->writeLine("return sk_FragColor;");
    478                 break;
    479             case Program::kVertex_Kind:
    480                 this->writeLine("return _out;");
    481                 break;
    482             default:
    483                 ASSERT(false);
    484         }
    485     }
    486     fIndentation--;
    487     this->writeLine("}");
    488 
    489     fOut = oldOut;
    490     this->write(fFunctionHeader);
    491     this->write(buffer.str());
    492 }
    493 
    494 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers,
    495                                        bool globalContext) {
    496     if (modifiers.fFlags & Modifiers::kOut_Flag) {
    497         this->write("thread ");
    498     }
    499     if (modifiers.fFlags & Modifiers::kConst_Flag) {
    500         this->write("const ");
    501     }
    502 }
    503 
    504 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
    505     if ("sk_PerVertex" == intf.fTypeName) {
    506         return;
    507     }
    508     this->writeModifiers(intf.fVariable.fModifiers, true);
    509     this->writeLine(intf.fTypeName + " {");
    510     fIndentation++;
    511     const Type* structType = &intf.fVariable.fType;
    512     while (Type::kArray_Kind == structType->kind()) {
    513         structType = &structType->componentType();
    514     }
    515     for (const auto& f : structType->fields()) {
    516         this->writeModifiers(f.fModifiers, false);
    517         this->writeType(*f.fType);
    518         this->writeLine(" " + f.fName + ";");
    519     }
    520     fIndentation--;
    521     this->write("}");
    522     if (intf.fInstanceName.size()) {
    523         this->write(" ");
    524         this->write(intf.fInstanceName);
    525         for (const auto& size : intf.fSizes) {
    526             this->write("[");
    527             if (size) {
    528                 this->writeExpression(*size, kTopLevel_Precedence);
    529             }
    530             this->write("]");
    531         }
    532     }
    533     this->writeLine(";");
    534 }
    535 
    536 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) {
    537     this->writeExpression(value, kTopLevel_Precedence);
    538 }
    539 
    540 void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) {
    541     ASSERT(decl.fVars.size() > 0);
    542     bool wroteType = false;
    543     for (const auto& stmt : decl.fVars) {
    544         VarDeclaration& var = (VarDeclaration&) *stmt;
    545         if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag |
    546                                            Modifiers::kUniform_Flag)) {
    547             ASSERT(global);
    548             continue;
    549         }
    550         if (wroteType) {
    551             this->write(", ");
    552         } else {
    553             this->writeModifiers(var.fVar->fModifiers, global);
    554             this->writeType(decl.fBaseType);
    555             this->write(" ");
    556             wroteType = true;
    557         }
    558         this->write(var.fVar->fName);
    559         for (const auto& size : var.fSizes) {
    560             this->write("[");
    561             if (size) {
    562                 this->writeExpression(*size, kTopLevel_Precedence);
    563             }
    564             this->write("]");
    565         }
    566         if (var.fValue) {
    567             this->write(" = ");
    568             this->writeVarInitializer(*var.fVar, *var.fValue);
    569         }
    570         if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) {
    571             if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) {
    572                 fHeader.writeText("#extension ");
    573                 fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString());
    574                 fHeader.writeText(" : require\n");
    575             }
    576             fFoundImageDecl = true;
    577         }
    578     }
    579     if (wroteType) {
    580         this->write(";");
    581     }
    582 }
    583 
    584 void MetalCodeGenerator::writeStatement(const Statement& s) {
    585     switch (s.fKind) {
    586         case Statement::kBlock_Kind:
    587             this->writeBlock((Block&) s);
    588             break;
    589         case Statement::kExpression_Kind:
    590             this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence);
    591             this->write(";");
    592             break;
    593         case Statement::kReturn_Kind:
    594             this->writeReturnStatement((ReturnStatement&) s);
    595             break;
    596         case Statement::kVarDeclarations_Kind:
    597             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false);
    598             break;
    599         case Statement::kIf_Kind:
    600             this->writeIfStatement((IfStatement&) s);
    601             break;
    602         case Statement::kFor_Kind:
    603             this->writeForStatement((ForStatement&) s);
    604             break;
    605         case Statement::kWhile_Kind:
    606             this->writeWhileStatement((WhileStatement&) s);
    607             break;
    608         case Statement::kDo_Kind:
    609             this->writeDoStatement((DoStatement&) s);
    610             break;
    611         case Statement::kSwitch_Kind:
    612             this->writeSwitchStatement((SwitchStatement&) s);
    613             break;
    614         case Statement::kBreak_Kind:
    615             this->write("break;");
    616             break;
    617         case Statement::kContinue_Kind:
    618             this->write("continue;");
    619             break;
    620         case Statement::kDiscard_Kind:
    621             this->write("discard;");
    622             break;
    623         case Statement::kNop_Kind:
    624             this->write(";");
    625             break;
    626         default:
    627             ABORT("unsupported statement: %s", s.description().c_str());
    628     }
    629 }
    630 
    631 void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) {
    632     for (const auto& s : statements) {
    633         if (!s->isEmpty()) {
    634             this->writeStatement(*s);
    635             this->writeLine();
    636         }
    637     }
    638 }
    639 
    640 void MetalCodeGenerator::writeBlock(const Block& b) {
    641     this->writeLine("{");
    642     fIndentation++;
    643     this->writeStatements(b.fStatements);
    644     fIndentation--;
    645     this->write("}");
    646 }
    647 
    648 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) {
    649     this->write("if (");
    650     this->writeExpression(*stmt.fTest, kTopLevel_Precedence);
    651     this->write(") ");
    652     this->writeStatement(*stmt.fIfTrue);
    653     if (stmt.fIfFalse) {
    654         this->write(" else ");
    655         this->writeStatement(*stmt.fIfFalse);
    656     }
    657 }
    658 
    659 void MetalCodeGenerator::writeForStatement(const ForStatement& f) {
    660     this->write("for (");
    661     if (f.fInitializer && !f.fInitializer->isEmpty()) {
    662         this->writeStatement(*f.fInitializer);
    663     } else {
    664         this->write("; ");
    665     }
    666     if (f.fTest) {
    667         this->writeExpression(*f.fTest, kTopLevel_Precedence);
    668     }
    669     this->write("; ");
    670     if (f.fNext) {
    671         this->writeExpression(*f.fNext, kTopLevel_Precedence);
    672     }
    673     this->write(") ");
    674     this->writeStatement(*f.fStatement);
    675 }
    676 
    677 void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) {
    678     this->write("while (");
    679     this->writeExpression(*w.fTest, kTopLevel_Precedence);
    680     this->write(") ");
    681     this->writeStatement(*w.fStatement);
    682 }
    683 
    684 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) {
    685     this->write("do ");
    686     this->writeStatement(*d.fStatement);
    687     this->write(" while (");
    688     this->writeExpression(*d.fTest, kTopLevel_Precedence);
    689     this->write(");");
    690 }
    691 
    692 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) {
    693     this->write("switch (");
    694     this->writeExpression(*s.fValue, kTopLevel_Precedence);
    695     this->writeLine(") {");
    696     fIndentation++;
    697     for (const auto& c : s.fCases) {
    698         if (c->fValue) {
    699             this->write("case ");
    700             this->writeExpression(*c->fValue, kTopLevel_Precedence);
    701             this->writeLine(":");
    702         } else {
    703             this->writeLine("default:");
    704         }
    705         fIndentation++;
    706         for (const auto& stmt : c->fStatements) {
    707             this->writeStatement(*stmt);
    708             this->writeLine();
    709         }
    710         fIndentation--;
    711     }
    712     fIndentation--;
    713     this->write("}");
    714 }
    715 
    716 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) {
    717     this->write("return");
    718     if (r.fExpression) {
    719         this->write(" ");
    720         this->writeExpression(*r.fExpression, kTopLevel_Precedence);
    721     }
    722     this->write(";");
    723 }
    724 
    725 void MetalCodeGenerator::writeHeader() {
    726     this->write("#include <metal_stdlib>\n");
    727     this->write("#include <simd/simd.h>\n");
    728     this->write("using namespace metal;\n");
    729 }
    730 
    731 void MetalCodeGenerator::writeUniformStruct() {
    732     for (const auto& e : fProgram.fElements) {
    733         if (ProgramElement::kVar_Kind == e->fKind) {
    734             VarDeclarations& decls = (VarDeclarations&) *e;
    735             if (!decls.fVars.size()) {
    736                 continue;
    737             }
    738             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
    739             if (first.fModifiers.fFlags & Modifiers::kUniform_Flag) {
    740                 if (-1 == fUniformBuffer) {
    741                     this->write("struct Uniforms {\n");
    742                     fUniformBuffer = first.fModifiers.fLayout.fSet;
    743                     if (-1 == fUniformBuffer) {
    744                         fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'");
    745                     }
    746                 } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) {
    747                     if (-1 == fUniformBuffer) {
    748                         fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have "
    749                                     "the same 'layout(set=...)'");
    750                     }
    751                 }
    752                 this->write("    ");
    753                 this->writeType(first.fType);
    754                 this->write(" ");
    755                 for (const auto& stmt : decls.fVars) {
    756                     VarDeclaration& var = (VarDeclaration&) *stmt;
    757                     this->write(var.fVar->fName);
    758                 }
    759                 this->write(";\n");
    760             }
    761         }
    762     }
    763     if (-1 != fUniformBuffer) {
    764         this->write("};\n");
    765     }
    766 }
    767 
    768 void MetalCodeGenerator::writeInputStruct() {
    769     this->write("struct Inputs {\n");
    770     if (Program::kFragment_Kind == fProgram.fKind) {
    771         this->write("    float4 position [[position]];\n");
    772     }
    773     for (const auto& e : fProgram.fElements) {
    774         if (ProgramElement::kVar_Kind == e->fKind) {
    775             VarDeclarations& decls = (VarDeclarations&) *e;
    776             if (!decls.fVars.size()) {
    777                 continue;
    778             }
    779             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
    780             if (first.fModifiers.fFlags & Modifiers::kIn_Flag &&
    781                 -1 == first.fModifiers.fLayout.fBuiltin) {
    782                 this->write("    ");
    783                 this->writeType(first.fType);
    784                 this->write(" ");
    785                 for (const auto& stmt : decls.fVars) {
    786                     VarDeclaration& var = (VarDeclaration&) *stmt;
    787                     this->write(var.fVar->fName);
    788                     if (-1 != var.fVar->fModifiers.fLayout.fLocation) {
    789                         this->write("  [[attribute(" +
    790                                     to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]");
    791                     }
    792                 }
    793                 this->write(";\n");
    794             }
    795         }
    796     }
    797     this->write("};\n");
    798 }
    799 
    800 void MetalCodeGenerator::writeOutputStruct() {
    801     this->write("struct Outputs {\n");
    802     this->write("    float4 position [[position]];\n");
    803     for (const auto& e : fProgram.fElements) {
    804         if (ProgramElement::kVar_Kind == e->fKind) {
    805             VarDeclarations& decls = (VarDeclarations&) *e;
    806             if (!decls.fVars.size()) {
    807                 continue;
    808             }
    809             const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar;
    810             if (first.fModifiers.fFlags & Modifiers::kOut_Flag &&
    811                 -1 == first.fModifiers.fLayout.fBuiltin) {
    812                 this->write("    ");
    813                 this->writeType(first.fType);
    814                 this->write(" ");
    815                 for (const auto& stmt : decls.fVars) {
    816                     VarDeclaration& var = (VarDeclaration&) *stmt;
    817                     this->write(var.fVar->fName);
    818                 }
    819                 this->write(";\n");
    820             }
    821         }
    822     }    this->write("};\n");
    823 }
    824 
    825 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) {
    826     switch (e.fKind) {
    827         case ProgramElement::kExtension_Kind:
    828             break;
    829         case ProgramElement::kVar_Kind: {
    830             VarDeclarations& decl = (VarDeclarations&) e;
    831             if (decl.fVars.size() > 0) {
    832                 int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin;
    833                 if (-1 == builtin) {
    834                     // normal var
    835                     this->writeVarDeclarations(decl, true);
    836                     this->writeLine();
    837                 } else if (SK_FRAGCOLOR_BUILTIN == builtin) {
    838                     // ignore
    839                 }
    840             }
    841             break;
    842         }
    843         case ProgramElement::kInterfaceBlock_Kind:
    844             this->writeInterfaceBlock((InterfaceBlock&) e);
    845             break;
    846         case ProgramElement::kFunction_Kind:
    847             this->writeFunction((FunctionDefinition&) e);
    848             break;
    849         case ProgramElement::kModifiers_Kind:
    850             this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true);
    851             this->writeLine(";");
    852             break;
    853         default:
    854             printf("%s\n", e.description().c_str());
    855             ABORT("unsupported program element");
    856     }
    857 }
    858 
    859 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) {
    860     switch (e.fKind) {
    861         case Expression::kFunctionCall_Kind: {
    862             const FunctionCall& f = (const FunctionCall&) e;
    863             Requirements result = this->requirements(f.fFunction);
    864             for (const auto& e : f.fArguments) {
    865                 result |= this->requirements(*e);
    866             }
    867             return result;
    868         }
    869         case Expression::kConstructor_Kind: {
    870             const Constructor& c = (const Constructor&) e;
    871             Requirements result = kNo_Requirements;
    872             for (const auto& e : c.fArguments) {
    873                 result |= this->requirements(*e);
    874             }
    875             return result;
    876         }
    877         case Expression::kFieldAccess_Kind:
    878             return this->requirements(*((const FieldAccess&) e).fBase);
    879         case Expression::kSwizzle_Kind:
    880             return this->requirements(*((const Swizzle&) e).fBase);
    881         case Expression::kBinary_Kind: {
    882             const BinaryExpression& b = (const BinaryExpression&) e;
    883             return this->requirements(*b.fLeft) | this->requirements(*b.fRight);
    884         }
    885         case Expression::kIndex_Kind: {
    886             const IndexExpression& idx = (const IndexExpression&) e;
    887             return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex);
    888         }
    889         case Expression::kPrefix_Kind:
    890             return this->requirements(*((const PrefixExpression&) e).fOperand);
    891         case Expression::kPostfix_Kind:
    892             return this->requirements(*((const PostfixExpression&) e).fOperand);
    893         case Expression::kTernary_Kind: {
    894             const TernaryExpression& t = (const TernaryExpression&) e;
    895             return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) |
    896                    this->requirements(*t.fIfFalse);
    897         }
    898         case Expression::kVariableReference_Kind: {
    899             const VariableReference& v = (const VariableReference&) e;
    900             Requirements result = kNo_Requirements;
    901             if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) {
    902                 result = kInputs_Requirement;
    903             } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) {
    904                 if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) {
    905                     result = kInputs_Requirement;
    906                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) {
    907                     result = kOutputs_Requirement;
    908                 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) {
    909                     result = kUniforms_Requirement;
    910                 }
    911             }
    912             return result;
    913         }
    914         default:
    915             return kNo_Requirements;
    916     }
    917 }
    918 
    919 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) {
    920     switch (s.fKind) {
    921         case Statement::kBlock_Kind: {
    922             Requirements result = kNo_Requirements;
    923             for (const auto& child : ((const Block&) s).fStatements) {
    924                 result |= this->requirements(*child);
    925             }
    926             return result;
    927         }
    928         case Statement::kExpression_Kind:
    929             return this->requirements(*((const ExpressionStatement&) s).fExpression);
    930         case Statement::kReturn_Kind: {
    931             const ReturnStatement& r = (const ReturnStatement&) s;
    932             if (r.fExpression) {
    933                 return this->requirements(*r.fExpression);
    934             }
    935             return kNo_Requirements;
    936         }
    937         case Statement::kIf_Kind: {
    938             const IfStatement& i = (const IfStatement&) s;
    939             return this->requirements(*i.fTest) |
    940                    this->requirements(*i.fIfTrue) |
    941                    (i.fIfFalse && this->requirements(*i.fIfFalse));
    942         }
    943         case Statement::kFor_Kind: {
    944             const ForStatement& f = (const ForStatement&) s;
    945             return this->requirements(*f.fInitializer) |
    946                    this->requirements(*f.fTest) |
    947                    this->requirements(*f.fNext) |
    948                    this->requirements(*f.fStatement);
    949         }
    950         case Statement::kWhile_Kind: {
    951             const WhileStatement& w = (const WhileStatement&) s;
    952             return this->requirements(*w.fTest) |
    953                    this->requirements(*w.fStatement);
    954         }
    955         case Statement::kDo_Kind: {
    956             const DoStatement& d = (const DoStatement&) s;
    957             return this->requirements(*d.fTest) |
    958                    this->requirements(*d.fStatement);
    959         }
    960         case Statement::kSwitch_Kind: {
    961             const SwitchStatement& sw = (const SwitchStatement&) s;
    962             Requirements result = this->requirements(*sw.fValue);
    963             for (const auto& c : sw.fCases) {
    964                 for (const auto& st : c->fStatements) {
    965                     result |= this->requirements(*st);
    966                 }
    967             }
    968             return result;
    969         }
    970         default:
    971             return kNo_Requirements;
    972     }
    973 }
    974 
    975 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) {
    976     if (f.fBuiltin) {
    977         return kNo_Requirements;
    978     }
    979     auto found = fRequirements.find(&f);
    980     if (found == fRequirements.end()) {
    981         for (const auto& e : fProgram.fElements) {
    982             if (ProgramElement::kFunction_Kind == e->fKind) {
    983                 const FunctionDefinition& def = (const FunctionDefinition&) *e;
    984                 if (&def.fDeclaration == &f) {
    985                     Requirements reqs = this->requirements(*def.fBody);
    986                     fRequirements[&f] = reqs;
    987                     return reqs;
    988                 }
    989             }
    990         }
    991     }
    992     return found->second;
    993 }
    994 
    995 bool MetalCodeGenerator::generateCode() {
    996     OutputStream* rawOut = fOut;
    997     fOut = &fHeader;
    998     fProgramKind = fProgram.fKind;
    999     this->writeHeader();
   1000     this->writeUniformStruct();
   1001     this->writeInputStruct();
   1002     if (Program::kVertex_Kind == fProgram.fKind) {
   1003         this->writeOutputStruct();
   1004     }
   1005     StringStream body;
   1006     fOut = &body;
   1007     for (const auto& e : fProgram.fElements) {
   1008         this->writeProgramElement(*e);
   1009     }
   1010     fOut = rawOut;
   1011 
   1012     write_stringstream(fHeader, *rawOut);
   1013     write_stringstream(body, *rawOut);
   1014     return true;
   1015 }
   1016 
   1017 }
   1018