1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #include "SkSLMetalCodeGenerator.h" 9 10 #include "SkSLCompiler.h" 11 #include "ir/SkSLExpressionStatement.h" 12 #include "ir/SkSLExtension.h" 13 #include "ir/SkSLIndexExpression.h" 14 #include "ir/SkSLModifiersDeclaration.h" 15 #include "ir/SkSLNop.h" 16 #include "ir/SkSLVariableReference.h" 17 18 namespace SkSL { 19 20 void MetalCodeGenerator::write(const char* s) { 21 if (!s[0]) { 22 return; 23 } 24 if (fAtLineStart) { 25 for (int i = 0; i < fIndentation; i++) { 26 fOut->writeText(" "); 27 } 28 } 29 fOut->writeText(s); 30 fAtLineStart = false; 31 } 32 33 void MetalCodeGenerator::writeLine(const char* s) { 34 this->write(s); 35 fOut->writeText(fLineEnding); 36 fAtLineStart = true; 37 } 38 39 void MetalCodeGenerator::write(const String& s) { 40 this->write(s.c_str()); 41 } 42 43 void MetalCodeGenerator::writeLine(const String& s) { 44 this->writeLine(s.c_str()); 45 } 46 47 void MetalCodeGenerator::writeLine() { 48 this->writeLine(""); 49 } 50 51 void MetalCodeGenerator::writeExtension(const Extension& ext) { 52 this->writeLine("#extension " + ext.fName + " : enable"); 53 } 54 55 void MetalCodeGenerator::writeType(const Type& type) { 56 switch (type.kind()) { 57 case Type::kStruct_Kind: 58 for (const Type* search : fWrittenStructs) { 59 if (*search == type) { 60 // already written 61 this->write(type.name()); 62 return; 63 } 64 } 65 fWrittenStructs.push_back(&type); 66 this->writeLine("struct " + type.name() + " {"); 67 fIndentation++; 68 for (const auto& f : type.fields()) { 69 this->writeModifiers(f.fModifiers, false); 70 // sizes (which must be static in structs) are part of the type name here 71 this->writeType(*f.fType); 72 this->writeLine(" " + f.fName + ";"); 73 } 74 fIndentation--; 75 this->write("}"); 76 break; 77 case Type::kVector_Kind: 78 this->writeType(type.componentType()); 79 this->write(to_string(type.columns())); 80 break; 81 default: 82 this->write(type.name()); 83 } 84 } 85 86 void MetalCodeGenerator::writeExpression(const Expression& expr, Precedence parentPrecedence) { 87 switch (expr.fKind) { 88 case Expression::kBinary_Kind: 89 this->writeBinaryExpression((BinaryExpression&) expr, parentPrecedence); 90 break; 91 case Expression::kBoolLiteral_Kind: 92 this->writeBoolLiteral((BoolLiteral&) expr); 93 break; 94 case Expression::kConstructor_Kind: 95 this->writeConstructor((Constructor&) expr); 96 break; 97 case Expression::kIntLiteral_Kind: 98 this->writeIntLiteral((IntLiteral&) expr); 99 break; 100 case Expression::kFieldAccess_Kind: 101 this->writeFieldAccess(((FieldAccess&) expr)); 102 break; 103 case Expression::kFloatLiteral_Kind: 104 this->writeFloatLiteral(((FloatLiteral&) expr)); 105 break; 106 case Expression::kFunctionCall_Kind: 107 this->writeFunctionCall((FunctionCall&) expr); 108 break; 109 case Expression::kPrefix_Kind: 110 this->writePrefixExpression((PrefixExpression&) expr, parentPrecedence); 111 break; 112 case Expression::kPostfix_Kind: 113 this->writePostfixExpression((PostfixExpression&) expr, parentPrecedence); 114 break; 115 case Expression::kSetting_Kind: 116 this->writeSetting((Setting&) expr); 117 break; 118 case Expression::kSwizzle_Kind: 119 this->writeSwizzle((Swizzle&) expr); 120 break; 121 case Expression::kVariableReference_Kind: 122 this->writeVariableReference((VariableReference&) expr); 123 break; 124 case Expression::kTernary_Kind: 125 this->writeTernaryExpression((TernaryExpression&) expr, parentPrecedence); 126 break; 127 case Expression::kIndex_Kind: 128 this->writeIndexExpression((IndexExpression&) expr); 129 break; 130 default: 131 ABORT("unsupported expression: %s", expr.description().c_str()); 132 } 133 } 134 135 void MetalCodeGenerator::writeFunctionCall(const FunctionCall& c) { 136 if (c.fFunction.fBuiltin && "atan" == c.fFunction.fName && 2 == c.fArguments.size()) { 137 this->write("atan2"); 138 } else { 139 this->write(c.fFunction.fName); 140 } 141 this->write("("); 142 const char* separator = ""; 143 if (this->requirements(c.fFunction) & kInputs_Requirement) { 144 this->write("_in"); 145 separator = ", "; 146 } 147 if (this->requirements(c.fFunction) & kOutputs_Requirement) { 148 this->write(separator); 149 this->write("_out"); 150 separator = ", "; 151 } 152 if (this->requirements(c.fFunction) & kUniforms_Requirement) { 153 this->write(separator); 154 this->write("_uniforms"); 155 separator = ", "; 156 } 157 for (size_t i = 0; i < c.fArguments.size(); ++i) { 158 const Expression& arg = *c.fArguments[i]; 159 this->write(separator); 160 separator = ", "; 161 if (c.fFunction.fParameters[i]->fModifiers.fFlags & Modifiers::kOut_Flag) { 162 this->write("&"); 163 } 164 this->writeExpression(arg, kSequence_Precedence); 165 } 166 this->write(")"); 167 } 168 169 void MetalCodeGenerator::writeConstructor(const Constructor& c) { 170 this->writeType(c.fType); 171 this->write("("); 172 const char* separator = ""; 173 int scalarCount = 0; 174 for (const auto& arg : c.fArguments) { 175 this->write(separator); 176 separator = ", "; 177 if (Type::kMatrix_Kind == c.fType.kind() && Type::kScalar_Kind == arg->fType.kind()) { 178 // float2x2(float, float, float, float) doesn't work in Metal 1, so we need to merge to 179 // float2x2(float2, float2). 180 if (!scalarCount) { 181 this->writeType(c.fType.componentType()); 182 this->write(to_string(c.fType.rows())); 183 this->write("("); 184 } 185 ++scalarCount; 186 } 187 this->writeExpression(*arg, kSequence_Precedence); 188 if (scalarCount && scalarCount == c.fType.rows()) { 189 this->write(")"); 190 scalarCount = 0; 191 } 192 } 193 this->write(")"); 194 } 195 196 void MetalCodeGenerator::writeFragCoord() { 197 this->write("_in.position"); 198 } 199 200 void MetalCodeGenerator::writeVariableReference(const VariableReference& ref) { 201 switch (ref.fVariable.fModifiers.fLayout.fBuiltin) { 202 case SK_FRAGCOLOR_BUILTIN: 203 this->write("sk_FragColor"); 204 break; 205 default: 206 if (Variable::kGlobal_Storage == ref.fVariable.fStorage) { 207 if (ref.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) { 208 this->write("_in."); 209 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) { 210 this->write("_out."); 211 } else if (ref.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) { 212 this->write("_uniforms."); 213 } else { 214 fErrors.error(ref.fVariable.fOffset, "Metal backend does not support global " 215 "variables"); 216 } 217 } 218 this->write(ref.fVariable.fName); 219 } 220 } 221 222 void MetalCodeGenerator::writeIndexExpression(const IndexExpression& expr) { 223 this->writeExpression(*expr.fBase, kPostfix_Precedence); 224 this->write("["); 225 this->writeExpression(*expr.fIndex, kTopLevel_Precedence); 226 this->write("]"); 227 } 228 229 void MetalCodeGenerator::writeFieldAccess(const FieldAccess& f) { 230 if (FieldAccess::kDefault_OwnerKind == f.fOwnerKind) { 231 this->writeExpression(*f.fBase, kPostfix_Precedence); 232 this->write("."); 233 } 234 switch (f.fBase->fType.fields()[f.fFieldIndex].fModifiers.fLayout.fBuiltin) { 235 case SK_CLIPDISTANCE_BUILTIN: 236 this->write("gl_ClipDistance"); 237 break; 238 case SK_POSITION_BUILTIN: 239 this->write("_out.position"); 240 break; 241 default: 242 this->write(f.fBase->fType.fields()[f.fFieldIndex].fName); 243 } 244 } 245 246 void MetalCodeGenerator::writeSwizzle(const Swizzle& swizzle) { 247 this->writeExpression(*swizzle.fBase, kPostfix_Precedence); 248 this->write("."); 249 for (int c : swizzle.fComponents) { 250 this->write(&("x\0y\0z\0w\0"[c * 2])); 251 } 252 } 253 254 MetalCodeGenerator::Precedence MetalCodeGenerator::GetBinaryPrecedence(Token::Kind op) { 255 switch (op) { 256 case Token::STAR: // fall through 257 case Token::SLASH: // fall through 258 case Token::PERCENT: return MetalCodeGenerator::kMultiplicative_Precedence; 259 case Token::PLUS: // fall through 260 case Token::MINUS: return MetalCodeGenerator::kAdditive_Precedence; 261 case Token::SHL: // fall through 262 case Token::SHR: return MetalCodeGenerator::kShift_Precedence; 263 case Token::LT: // fall through 264 case Token::GT: // fall through 265 case Token::LTEQ: // fall through 266 case Token::GTEQ: return MetalCodeGenerator::kRelational_Precedence; 267 case Token::EQEQ: // fall through 268 case Token::NEQ: return MetalCodeGenerator::kEquality_Precedence; 269 case Token::BITWISEAND: return MetalCodeGenerator::kBitwiseAnd_Precedence; 270 case Token::BITWISEXOR: return MetalCodeGenerator::kBitwiseXor_Precedence; 271 case Token::BITWISEOR: return MetalCodeGenerator::kBitwiseOr_Precedence; 272 case Token::LOGICALAND: return MetalCodeGenerator::kLogicalAnd_Precedence; 273 case Token::LOGICALXOR: return MetalCodeGenerator::kLogicalXor_Precedence; 274 case Token::LOGICALOR: return MetalCodeGenerator::kLogicalOr_Precedence; 275 case Token::EQ: // fall through 276 case Token::PLUSEQ: // fall through 277 case Token::MINUSEQ: // fall through 278 case Token::STAREQ: // fall through 279 case Token::SLASHEQ: // fall through 280 case Token::PERCENTEQ: // fall through 281 case Token::SHLEQ: // fall through 282 case Token::SHREQ: // fall through 283 case Token::LOGICALANDEQ: // fall through 284 case Token::LOGICALXOREQ: // fall through 285 case Token::LOGICALOREQ: // fall through 286 case Token::BITWISEANDEQ: // fall through 287 case Token::BITWISEXOREQ: // fall through 288 case Token::BITWISEOREQ: return MetalCodeGenerator::kAssignment_Precedence; 289 case Token::COMMA: return MetalCodeGenerator::kSequence_Precedence; 290 default: ABORT("unsupported binary operator"); 291 } 292 } 293 294 void MetalCodeGenerator::writeBinaryExpression(const BinaryExpression& b, 295 Precedence parentPrecedence) { 296 Precedence precedence = GetBinaryPrecedence(b.fOperator); 297 if (precedence >= parentPrecedence) { 298 this->write("("); 299 } 300 if (Compiler::IsAssignment(b.fOperator) && 301 Expression::kVariableReference_Kind == b.fLeft->fKind && 302 Variable::kParameter_Storage == ((VariableReference&) *b.fLeft).fVariable.fStorage && 303 (((VariableReference&) *b.fLeft).fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { 304 // writing to an out parameter. Since we have to turn those into pointers, we have to 305 // dereference it here. 306 this->write("*"); 307 } 308 this->writeExpression(*b.fLeft, precedence); 309 if (b.fOperator != Token::EQ && Compiler::IsAssignment(b.fOperator) && 310 Expression::kSwizzle_Kind == b.fLeft->fKind && !b.fLeft->hasSideEffects()) { 311 // This doesn't compile in Metal: 312 // float4 x = float4(1); 313 // x.xy *= float2x2(...); 314 // with the error message "non-const reference cannot bind to vector element", 315 // but switching it to x.xy = x.xy * float2x2(...) fixes it. We perform this tranformation 316 // as long as the LHS has no side effects, and hope for the best otherwise. 317 this->write(" = "); 318 this->writeExpression(*b.fLeft, kAssignment_Precedence); 319 this->write(" "); 320 String op = Compiler::OperatorName(b.fOperator); 321 ASSERT(op.endsWith("=")); 322 this->write(op.substr(0, op.size() - 1).c_str()); 323 this->write(" "); 324 } else { 325 this->write(String(" ") + Compiler::OperatorName(b.fOperator) + " "); 326 } 327 this->writeExpression(*b.fRight, precedence); 328 if (precedence >= parentPrecedence) { 329 this->write(")"); 330 } 331 } 332 333 void MetalCodeGenerator::writeTernaryExpression(const TernaryExpression& t, 334 Precedence parentPrecedence) { 335 if (kTernary_Precedence >= parentPrecedence) { 336 this->write("("); 337 } 338 this->writeExpression(*t.fTest, kTernary_Precedence); 339 this->write(" ? "); 340 this->writeExpression(*t.fIfTrue, kTernary_Precedence); 341 this->write(" : "); 342 this->writeExpression(*t.fIfFalse, kTernary_Precedence); 343 if (kTernary_Precedence >= parentPrecedence) { 344 this->write(")"); 345 } 346 } 347 348 void MetalCodeGenerator::writePrefixExpression(const PrefixExpression& p, 349 Precedence parentPrecedence) { 350 if (kPrefix_Precedence >= parentPrecedence) { 351 this->write("("); 352 } 353 this->write(Compiler::OperatorName(p.fOperator)); 354 this->writeExpression(*p.fOperand, kPrefix_Precedence); 355 if (kPrefix_Precedence >= parentPrecedence) { 356 this->write(")"); 357 } 358 } 359 360 void MetalCodeGenerator::writePostfixExpression(const PostfixExpression& p, 361 Precedence parentPrecedence) { 362 if (kPostfix_Precedence >= parentPrecedence) { 363 this->write("("); 364 } 365 this->writeExpression(*p.fOperand, kPostfix_Precedence); 366 this->write(Compiler::OperatorName(p.fOperator)); 367 if (kPostfix_Precedence >= parentPrecedence) { 368 this->write(")"); 369 } 370 } 371 372 void MetalCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { 373 this->write(b.fValue ? "true" : "false"); 374 } 375 376 void MetalCodeGenerator::writeIntLiteral(const IntLiteral& i) { 377 if (i.fType == *fContext.fUInt_Type) { 378 this->write(to_string(i.fValue & 0xffffffff) + "u"); 379 } else { 380 this->write(to_string((int32_t) i.fValue)); 381 } 382 } 383 384 void MetalCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { 385 this->write(to_string(f.fValue)); 386 } 387 388 void MetalCodeGenerator::writeSetting(const Setting& s) { 389 ABORT("internal error; setting was not folded to a constant during compilation\n"); 390 } 391 392 void MetalCodeGenerator::writeFunction(const FunctionDefinition& f) { 393 const char* separator = ""; 394 if ("main" == f.fDeclaration.fName) { 395 switch (fProgram.fKind) { 396 case Program::kFragment_Kind: 397 this->write("fragment half4 _frag"); 398 break; 399 case Program::kVertex_Kind: 400 this->write("vertex Outputs _vert"); 401 break; 402 default: 403 ASSERT(false); 404 } 405 this->write("(Inputs _in [[stage_in]]"); 406 if (-1 != fUniformBuffer) { 407 this->write(", constant Uniforms& _uniforms [[buffer(" + 408 to_string(fUniformBuffer) + ")]]"); 409 } 410 separator = ", "; 411 } else { 412 this->writeType(f.fDeclaration.fReturnType); 413 this->write(" " + f.fDeclaration.fName + "("); 414 if (this->requirements(f.fDeclaration) & kInputs_Requirement) { 415 this->write("Inputs _in"); 416 separator = ", "; 417 } 418 if (this->requirements(f.fDeclaration) & kOutputs_Requirement) { 419 this->write(separator); 420 this->write("thread Outputs& _out"); 421 separator = ", "; 422 } 423 if (this->requirements(f.fDeclaration) & kUniforms_Requirement) { 424 this->write(separator); 425 this->write("Uniforms _uniforms"); 426 separator = ", "; 427 } 428 } 429 for (const auto& param : f.fDeclaration.fParameters) { 430 this->write(separator); 431 separator = ", "; 432 this->writeModifiers(param->fModifiers, false); 433 std::vector<int> sizes; 434 const Type* type = ¶m->fType; 435 while (Type::kArray_Kind == type->kind()) { 436 sizes.push_back(type->columns()); 437 type = &type->componentType(); 438 } 439 this->writeType(*type); 440 if (param->fModifiers.fFlags & Modifiers::kOut_Flag) { 441 this->write("*"); 442 } 443 this->write(" " + param->fName); 444 for (int s : sizes) { 445 if (s <= 0) { 446 this->write("[]"); 447 } else { 448 this->write("[" + to_string(s) + "]"); 449 } 450 } 451 } 452 this->writeLine(") {"); 453 454 ASSERT(!fProgram.fSettings.fFragColorIsInOut); 455 456 if ("main" == f.fDeclaration.fName) { 457 switch (fProgram.fKind) { 458 case Program::kFragment_Kind: 459 this->writeLine(" half4 sk_FragColor;"); 460 break; 461 case Program::kVertex_Kind: 462 this->writeLine(" Outputs _out;"); 463 break; 464 default: 465 ASSERT(false); 466 } 467 } 468 fFunctionHeader = ""; 469 OutputStream* oldOut = fOut; 470 StringStream buffer; 471 fOut = &buffer; 472 fIndentation++; 473 this->writeStatements(((Block&) *f.fBody).fStatements); 474 if ("main" == f.fDeclaration.fName) { 475 switch (fProgram.fKind) { 476 case Program::kFragment_Kind: 477 this->writeLine("return sk_FragColor;"); 478 break; 479 case Program::kVertex_Kind: 480 this->writeLine("return _out;"); 481 break; 482 default: 483 ASSERT(false); 484 } 485 } 486 fIndentation--; 487 this->writeLine("}"); 488 489 fOut = oldOut; 490 this->write(fFunctionHeader); 491 this->write(buffer.str()); 492 } 493 494 void MetalCodeGenerator::writeModifiers(const Modifiers& modifiers, 495 bool globalContext) { 496 if (modifiers.fFlags & Modifiers::kOut_Flag) { 497 this->write("thread "); 498 } 499 if (modifiers.fFlags & Modifiers::kConst_Flag) { 500 this->write("const "); 501 } 502 } 503 504 void MetalCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) { 505 if ("sk_PerVertex" == intf.fTypeName) { 506 return; 507 } 508 this->writeModifiers(intf.fVariable.fModifiers, true); 509 this->writeLine(intf.fTypeName + " {"); 510 fIndentation++; 511 const Type* structType = &intf.fVariable.fType; 512 while (Type::kArray_Kind == structType->kind()) { 513 structType = &structType->componentType(); 514 } 515 for (const auto& f : structType->fields()) { 516 this->writeModifiers(f.fModifiers, false); 517 this->writeType(*f.fType); 518 this->writeLine(" " + f.fName + ";"); 519 } 520 fIndentation--; 521 this->write("}"); 522 if (intf.fInstanceName.size()) { 523 this->write(" "); 524 this->write(intf.fInstanceName); 525 for (const auto& size : intf.fSizes) { 526 this->write("["); 527 if (size) { 528 this->writeExpression(*size, kTopLevel_Precedence); 529 } 530 this->write("]"); 531 } 532 } 533 this->writeLine(";"); 534 } 535 536 void MetalCodeGenerator::writeVarInitializer(const Variable& var, const Expression& value) { 537 this->writeExpression(value, kTopLevel_Precedence); 538 } 539 540 void MetalCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, bool global) { 541 ASSERT(decl.fVars.size() > 0); 542 bool wroteType = false; 543 for (const auto& stmt : decl.fVars) { 544 VarDeclaration& var = (VarDeclaration&) *stmt; 545 if (var.fVar->fModifiers.fFlags & (Modifiers::kIn_Flag | Modifiers::kOut_Flag | 546 Modifiers::kUniform_Flag)) { 547 ASSERT(global); 548 continue; 549 } 550 if (wroteType) { 551 this->write(", "); 552 } else { 553 this->writeModifiers(var.fVar->fModifiers, global); 554 this->writeType(decl.fBaseType); 555 this->write(" "); 556 wroteType = true; 557 } 558 this->write(var.fVar->fName); 559 for (const auto& size : var.fSizes) { 560 this->write("["); 561 if (size) { 562 this->writeExpression(*size, kTopLevel_Precedence); 563 } 564 this->write("]"); 565 } 566 if (var.fValue) { 567 this->write(" = "); 568 this->writeVarInitializer(*var.fVar, *var.fValue); 569 } 570 if (!fFoundImageDecl && var.fVar->fType == *fContext.fImage2D_Type) { 571 if (fProgram.fSettings.fCaps->imageLoadStoreExtensionString()) { 572 fHeader.writeText("#extension "); 573 fHeader.writeText(fProgram.fSettings.fCaps->imageLoadStoreExtensionString()); 574 fHeader.writeText(" : require\n"); 575 } 576 fFoundImageDecl = true; 577 } 578 } 579 if (wroteType) { 580 this->write(";"); 581 } 582 } 583 584 void MetalCodeGenerator::writeStatement(const Statement& s) { 585 switch (s.fKind) { 586 case Statement::kBlock_Kind: 587 this->writeBlock((Block&) s); 588 break; 589 case Statement::kExpression_Kind: 590 this->writeExpression(*((ExpressionStatement&) s).fExpression, kTopLevel_Precedence); 591 this->write(";"); 592 break; 593 case Statement::kReturn_Kind: 594 this->writeReturnStatement((ReturnStatement&) s); 595 break; 596 case Statement::kVarDeclarations_Kind: 597 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, false); 598 break; 599 case Statement::kIf_Kind: 600 this->writeIfStatement((IfStatement&) s); 601 break; 602 case Statement::kFor_Kind: 603 this->writeForStatement((ForStatement&) s); 604 break; 605 case Statement::kWhile_Kind: 606 this->writeWhileStatement((WhileStatement&) s); 607 break; 608 case Statement::kDo_Kind: 609 this->writeDoStatement((DoStatement&) s); 610 break; 611 case Statement::kSwitch_Kind: 612 this->writeSwitchStatement((SwitchStatement&) s); 613 break; 614 case Statement::kBreak_Kind: 615 this->write("break;"); 616 break; 617 case Statement::kContinue_Kind: 618 this->write("continue;"); 619 break; 620 case Statement::kDiscard_Kind: 621 this->write("discard;"); 622 break; 623 case Statement::kNop_Kind: 624 this->write(";"); 625 break; 626 default: 627 ABORT("unsupported statement: %s", s.description().c_str()); 628 } 629 } 630 631 void MetalCodeGenerator::writeStatements(const std::vector<std::unique_ptr<Statement>>& statements) { 632 for (const auto& s : statements) { 633 if (!s->isEmpty()) { 634 this->writeStatement(*s); 635 this->writeLine(); 636 } 637 } 638 } 639 640 void MetalCodeGenerator::writeBlock(const Block& b) { 641 this->writeLine("{"); 642 fIndentation++; 643 this->writeStatements(b.fStatements); 644 fIndentation--; 645 this->write("}"); 646 } 647 648 void MetalCodeGenerator::writeIfStatement(const IfStatement& stmt) { 649 this->write("if ("); 650 this->writeExpression(*stmt.fTest, kTopLevel_Precedence); 651 this->write(") "); 652 this->writeStatement(*stmt.fIfTrue); 653 if (stmt.fIfFalse) { 654 this->write(" else "); 655 this->writeStatement(*stmt.fIfFalse); 656 } 657 } 658 659 void MetalCodeGenerator::writeForStatement(const ForStatement& f) { 660 this->write("for ("); 661 if (f.fInitializer && !f.fInitializer->isEmpty()) { 662 this->writeStatement(*f.fInitializer); 663 } else { 664 this->write("; "); 665 } 666 if (f.fTest) { 667 this->writeExpression(*f.fTest, kTopLevel_Precedence); 668 } 669 this->write("; "); 670 if (f.fNext) { 671 this->writeExpression(*f.fNext, kTopLevel_Precedence); 672 } 673 this->write(") "); 674 this->writeStatement(*f.fStatement); 675 } 676 677 void MetalCodeGenerator::writeWhileStatement(const WhileStatement& w) { 678 this->write("while ("); 679 this->writeExpression(*w.fTest, kTopLevel_Precedence); 680 this->write(") "); 681 this->writeStatement(*w.fStatement); 682 } 683 684 void MetalCodeGenerator::writeDoStatement(const DoStatement& d) { 685 this->write("do "); 686 this->writeStatement(*d.fStatement); 687 this->write(" while ("); 688 this->writeExpression(*d.fTest, kTopLevel_Precedence); 689 this->write(");"); 690 } 691 692 void MetalCodeGenerator::writeSwitchStatement(const SwitchStatement& s) { 693 this->write("switch ("); 694 this->writeExpression(*s.fValue, kTopLevel_Precedence); 695 this->writeLine(") {"); 696 fIndentation++; 697 for (const auto& c : s.fCases) { 698 if (c->fValue) { 699 this->write("case "); 700 this->writeExpression(*c->fValue, kTopLevel_Precedence); 701 this->writeLine(":"); 702 } else { 703 this->writeLine("default:"); 704 } 705 fIndentation++; 706 for (const auto& stmt : c->fStatements) { 707 this->writeStatement(*stmt); 708 this->writeLine(); 709 } 710 fIndentation--; 711 } 712 fIndentation--; 713 this->write("}"); 714 } 715 716 void MetalCodeGenerator::writeReturnStatement(const ReturnStatement& r) { 717 this->write("return"); 718 if (r.fExpression) { 719 this->write(" "); 720 this->writeExpression(*r.fExpression, kTopLevel_Precedence); 721 } 722 this->write(";"); 723 } 724 725 void MetalCodeGenerator::writeHeader() { 726 this->write("#include <metal_stdlib>\n"); 727 this->write("#include <simd/simd.h>\n"); 728 this->write("using namespace metal;\n"); 729 } 730 731 void MetalCodeGenerator::writeUniformStruct() { 732 for (const auto& e : fProgram.fElements) { 733 if (ProgramElement::kVar_Kind == e->fKind) { 734 VarDeclarations& decls = (VarDeclarations&) *e; 735 if (!decls.fVars.size()) { 736 continue; 737 } 738 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; 739 if (first.fModifiers.fFlags & Modifiers::kUniform_Flag) { 740 if (-1 == fUniformBuffer) { 741 this->write("struct Uniforms {\n"); 742 fUniformBuffer = first.fModifiers.fLayout.fSet; 743 if (-1 == fUniformBuffer) { 744 fErrors.error(decls.fOffset, "Metal uniforms must have 'layout(set=...)'"); 745 } 746 } else if (first.fModifiers.fLayout.fSet != fUniformBuffer) { 747 if (-1 == fUniformBuffer) { 748 fErrors.error(decls.fOffset, "Metal backend requires all uniforms to have " 749 "the same 'layout(set=...)'"); 750 } 751 } 752 this->write(" "); 753 this->writeType(first.fType); 754 this->write(" "); 755 for (const auto& stmt : decls.fVars) { 756 VarDeclaration& var = (VarDeclaration&) *stmt; 757 this->write(var.fVar->fName); 758 } 759 this->write(";\n"); 760 } 761 } 762 } 763 if (-1 != fUniformBuffer) { 764 this->write("};\n"); 765 } 766 } 767 768 void MetalCodeGenerator::writeInputStruct() { 769 this->write("struct Inputs {\n"); 770 if (Program::kFragment_Kind == fProgram.fKind) { 771 this->write(" float4 position [[position]];\n"); 772 } 773 for (const auto& e : fProgram.fElements) { 774 if (ProgramElement::kVar_Kind == e->fKind) { 775 VarDeclarations& decls = (VarDeclarations&) *e; 776 if (!decls.fVars.size()) { 777 continue; 778 } 779 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; 780 if (first.fModifiers.fFlags & Modifiers::kIn_Flag && 781 -1 == first.fModifiers.fLayout.fBuiltin) { 782 this->write(" "); 783 this->writeType(first.fType); 784 this->write(" "); 785 for (const auto& stmt : decls.fVars) { 786 VarDeclaration& var = (VarDeclaration&) *stmt; 787 this->write(var.fVar->fName); 788 if (-1 != var.fVar->fModifiers.fLayout.fLocation) { 789 this->write(" [[attribute(" + 790 to_string(var.fVar->fModifiers.fLayout.fLocation) + ")]]"); 791 } 792 } 793 this->write(";\n"); 794 } 795 } 796 } 797 this->write("};\n"); 798 } 799 800 void MetalCodeGenerator::writeOutputStruct() { 801 this->write("struct Outputs {\n"); 802 this->write(" float4 position [[position]];\n"); 803 for (const auto& e : fProgram.fElements) { 804 if (ProgramElement::kVar_Kind == e->fKind) { 805 VarDeclarations& decls = (VarDeclarations&) *e; 806 if (!decls.fVars.size()) { 807 continue; 808 } 809 const Variable& first = *((VarDeclaration&) *decls.fVars[0]).fVar; 810 if (first.fModifiers.fFlags & Modifiers::kOut_Flag && 811 -1 == first.fModifiers.fLayout.fBuiltin) { 812 this->write(" "); 813 this->writeType(first.fType); 814 this->write(" "); 815 for (const auto& stmt : decls.fVars) { 816 VarDeclaration& var = (VarDeclaration&) *stmt; 817 this->write(var.fVar->fName); 818 } 819 this->write(";\n"); 820 } 821 } 822 } this->write("};\n"); 823 } 824 825 void MetalCodeGenerator::writeProgramElement(const ProgramElement& e) { 826 switch (e.fKind) { 827 case ProgramElement::kExtension_Kind: 828 break; 829 case ProgramElement::kVar_Kind: { 830 VarDeclarations& decl = (VarDeclarations&) e; 831 if (decl.fVars.size() > 0) { 832 int builtin = ((VarDeclaration&) *decl.fVars[0]).fVar->fModifiers.fLayout.fBuiltin; 833 if (-1 == builtin) { 834 // normal var 835 this->writeVarDeclarations(decl, true); 836 this->writeLine(); 837 } else if (SK_FRAGCOLOR_BUILTIN == builtin) { 838 // ignore 839 } 840 } 841 break; 842 } 843 case ProgramElement::kInterfaceBlock_Kind: 844 this->writeInterfaceBlock((InterfaceBlock&) e); 845 break; 846 case ProgramElement::kFunction_Kind: 847 this->writeFunction((FunctionDefinition&) e); 848 break; 849 case ProgramElement::kModifiers_Kind: 850 this->writeModifiers(((ModifiersDeclaration&) e).fModifiers, true); 851 this->writeLine(";"); 852 break; 853 default: 854 printf("%s\n", e.description().c_str()); 855 ABORT("unsupported program element"); 856 } 857 } 858 859 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Expression& e) { 860 switch (e.fKind) { 861 case Expression::kFunctionCall_Kind: { 862 const FunctionCall& f = (const FunctionCall&) e; 863 Requirements result = this->requirements(f.fFunction); 864 for (const auto& e : f.fArguments) { 865 result |= this->requirements(*e); 866 } 867 return result; 868 } 869 case Expression::kConstructor_Kind: { 870 const Constructor& c = (const Constructor&) e; 871 Requirements result = kNo_Requirements; 872 for (const auto& e : c.fArguments) { 873 result |= this->requirements(*e); 874 } 875 return result; 876 } 877 case Expression::kFieldAccess_Kind: 878 return this->requirements(*((const FieldAccess&) e).fBase); 879 case Expression::kSwizzle_Kind: 880 return this->requirements(*((const Swizzle&) e).fBase); 881 case Expression::kBinary_Kind: { 882 const BinaryExpression& b = (const BinaryExpression&) e; 883 return this->requirements(*b.fLeft) | this->requirements(*b.fRight); 884 } 885 case Expression::kIndex_Kind: { 886 const IndexExpression& idx = (const IndexExpression&) e; 887 return this->requirements(*idx.fBase) | this->requirements(*idx.fIndex); 888 } 889 case Expression::kPrefix_Kind: 890 return this->requirements(*((const PrefixExpression&) e).fOperand); 891 case Expression::kPostfix_Kind: 892 return this->requirements(*((const PostfixExpression&) e).fOperand); 893 case Expression::kTernary_Kind: { 894 const TernaryExpression& t = (const TernaryExpression&) e; 895 return this->requirements(*t.fTest) | this->requirements(*t.fIfTrue) | 896 this->requirements(*t.fIfFalse); 897 } 898 case Expression::kVariableReference_Kind: { 899 const VariableReference& v = (const VariableReference&) e; 900 Requirements result = kNo_Requirements; 901 if (v.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN) { 902 result = kInputs_Requirement; 903 } else if (Variable::kGlobal_Storage == v.fVariable.fStorage) { 904 if (v.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) { 905 result = kInputs_Requirement; 906 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag) { 907 result = kOutputs_Requirement; 908 } else if (v.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag) { 909 result = kUniforms_Requirement; 910 } 911 } 912 return result; 913 } 914 default: 915 return kNo_Requirements; 916 } 917 } 918 919 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const Statement& s) { 920 switch (s.fKind) { 921 case Statement::kBlock_Kind: { 922 Requirements result = kNo_Requirements; 923 for (const auto& child : ((const Block&) s).fStatements) { 924 result |= this->requirements(*child); 925 } 926 return result; 927 } 928 case Statement::kExpression_Kind: 929 return this->requirements(*((const ExpressionStatement&) s).fExpression); 930 case Statement::kReturn_Kind: { 931 const ReturnStatement& r = (const ReturnStatement&) s; 932 if (r.fExpression) { 933 return this->requirements(*r.fExpression); 934 } 935 return kNo_Requirements; 936 } 937 case Statement::kIf_Kind: { 938 const IfStatement& i = (const IfStatement&) s; 939 return this->requirements(*i.fTest) | 940 this->requirements(*i.fIfTrue) | 941 (i.fIfFalse && this->requirements(*i.fIfFalse)); 942 } 943 case Statement::kFor_Kind: { 944 const ForStatement& f = (const ForStatement&) s; 945 return this->requirements(*f.fInitializer) | 946 this->requirements(*f.fTest) | 947 this->requirements(*f.fNext) | 948 this->requirements(*f.fStatement); 949 } 950 case Statement::kWhile_Kind: { 951 const WhileStatement& w = (const WhileStatement&) s; 952 return this->requirements(*w.fTest) | 953 this->requirements(*w.fStatement); 954 } 955 case Statement::kDo_Kind: { 956 const DoStatement& d = (const DoStatement&) s; 957 return this->requirements(*d.fTest) | 958 this->requirements(*d.fStatement); 959 } 960 case Statement::kSwitch_Kind: { 961 const SwitchStatement& sw = (const SwitchStatement&) s; 962 Requirements result = this->requirements(*sw.fValue); 963 for (const auto& c : sw.fCases) { 964 for (const auto& st : c->fStatements) { 965 result |= this->requirements(*st); 966 } 967 } 968 return result; 969 } 970 default: 971 return kNo_Requirements; 972 } 973 } 974 975 MetalCodeGenerator::Requirements MetalCodeGenerator::requirements(const FunctionDeclaration& f) { 976 if (f.fBuiltin) { 977 return kNo_Requirements; 978 } 979 auto found = fRequirements.find(&f); 980 if (found == fRequirements.end()) { 981 for (const auto& e : fProgram.fElements) { 982 if (ProgramElement::kFunction_Kind == e->fKind) { 983 const FunctionDefinition& def = (const FunctionDefinition&) *e; 984 if (&def.fDeclaration == &f) { 985 Requirements reqs = this->requirements(*def.fBody); 986 fRequirements[&f] = reqs; 987 return reqs; 988 } 989 } 990 } 991 } 992 return found->second; 993 } 994 995 bool MetalCodeGenerator::generateCode() { 996 OutputStream* rawOut = fOut; 997 fOut = &fHeader; 998 fProgramKind = fProgram.fKind; 999 this->writeHeader(); 1000 this->writeUniformStruct(); 1001 this->writeInputStruct(); 1002 if (Program::kVertex_Kind == fProgram.fKind) { 1003 this->writeOutputStruct(); 1004 } 1005 StringStream body; 1006 fOut = &body; 1007 for (const auto& e : fProgram.fElements) { 1008 this->writeProgramElement(*e); 1009 } 1010 fOut = rawOut; 1011 1012 write_stringstream(fHeader, *rawOut); 1013 write_stringstream(body, *rawOut); 1014 return true; 1015 } 1016 1017 } 1018