1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #include "SkSLSPIRVCodeGenerator.h" 9 10 #include "GLSL.std.450.h" 11 12 #include "ir/SkSLExpressionStatement.h" 13 #include "ir/SkSLExtension.h" 14 #include "ir/SkSLIndexExpression.h" 15 #include "ir/SkSLVariableReference.h" 16 #include "SkSLCompiler.h" 17 18 namespace SkSL { 19 20 #define SPIRV_DEBUG 0 21 22 static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number 23 24 void SPIRVCodeGenerator::setupIntrinsics() { 25 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \ 26 GLSLstd450 ## x, GLSLstd450 ## x) 27 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \ 28 GLSLstd450 ## ifFloat, \ 29 GLSLstd450 ## ifInt, \ 30 GLSLstd450 ## ifUInt, \ 31 SpvOpUndef) 32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \ 33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \ 34 k ## x ## _SpecialIntrinsic) 35 fIntrinsicMap[String("round")] = ALL_GLSL(Round); 36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven); 37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc); 38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs); 39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign); 40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor); 41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil); 42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract); 43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians); 44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees); 45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin); 46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos); 47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan); 48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin); 49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos); 50 fIntrinsicMap[String("atan")] = SPECIAL(Atan); 51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh); 52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh); 53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh); 54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh); 55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh); 56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh); 57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow); 58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp); 59 fIntrinsicMap[String("log")] = ALL_GLSL(Log); 60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2); 61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2); 62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt); 63 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt); 64 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant); 65 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse); 66 fIntrinsicMap[String("mod")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod, 67 SpvOpSMod, SpvOpUMod, SpvOpUndef); 68 fIntrinsicMap[String("min")] = BY_TYPE_GLSL(FMin, SMin, UMin); 69 fIntrinsicMap[String("max")] = BY_TYPE_GLSL(FMax, SMax, UMax); 70 fIntrinsicMap[String("clamp")] = BY_TYPE_GLSL(FClamp, SClamp, UClamp); 71 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot, 72 SpvOpUndef, SpvOpUndef, SpvOpUndef); 73 fIntrinsicMap[String("mix")] = ALL_GLSL(FMix); 74 fIntrinsicMap[String("step")] = ALL_GLSL(Step); 75 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep); 76 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma); 77 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp); 78 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp); 79 80 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \ 81 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type) 82 PACK(Snorm4x8); 83 PACK(Unorm4x8); 84 PACK(Snorm2x16); 85 PACK(Unorm2x16); 86 PACK(Half2x16); 87 PACK(Double2x32); 88 fIntrinsicMap[String("length")] = ALL_GLSL(Length); 89 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance); 90 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross); 91 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize); 92 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward); 93 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect); 94 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract); 95 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb); 96 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb); 97 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx, 98 SpvOpUndef, SpvOpUndef, SpvOpUndef); 99 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 100 SpvOpUndef, SpvOpUndef, SpvOpUndef); 101 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 102 SpvOpUndef, SpvOpUndef, SpvOpUndef); 103 fIntrinsicMap[String("texture")] = SPECIAL(Texture); 104 fIntrinsicMap[String("texelFetch")] = SPECIAL(TexelFetch); 105 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad); 106 107 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 108 SpvOpUndef, SpvOpUndef, SpvOpAny); 109 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 110 SpvOpUndef, SpvOpUndef, SpvOpAll); 111 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind, 112 SpvOpFOrdEqual, SpvOpIEqual, 113 SpvOpIEqual, SpvOpLogicalEqual); 114 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 115 SpvOpFOrdNotEqual, SpvOpINotEqual, 116 SpvOpINotEqual, 117 SpvOpLogicalNotEqual); 118 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 119 SpvOpSLessThan, SpvOpULessThan, 120 SpvOpFOrdLessThan, SpvOpUndef); 121 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 122 SpvOpSLessThanEqual, 123 SpvOpULessThanEqual, 124 SpvOpFOrdLessThanEqual, 125 SpvOpUndef); 126 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 127 SpvOpSGreaterThan, 128 SpvOpUGreaterThan, 129 SpvOpFOrdGreaterThan, 130 SpvOpUndef); 131 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 132 SpvOpSGreaterThanEqual, 133 SpvOpUGreaterThanEqual, 134 SpvOpFOrdGreaterThanEqual, 135 SpvOpUndef); 136 // interpolateAt* not yet supported... 137 } 138 139 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) { 140 #if SPIRV_DEBUG 141 out << "(" << word << ") "; 142 #else 143 out.write((const char*) &word, sizeof(word)); 144 #endif 145 } 146 147 static bool is_float(const Context& context, const Type& type) { 148 if (type.kind() == Type::kVector_Kind) { 149 return is_float(context, type.componentType()); 150 } 151 return type == *context.fFloat_Type || type == *context.fDouble_Type; 152 } 153 154 static bool is_signed(const Context& context, const Type& type) { 155 if (type.kind() == Type::kVector_Kind) { 156 return is_signed(context, type.componentType()); 157 } 158 return type == *context.fInt_Type; 159 } 160 161 static bool is_unsigned(const Context& context, const Type& type) { 162 if (type.kind() == Type::kVector_Kind) { 163 return is_unsigned(context, type.componentType()); 164 } 165 return type == *context.fUInt_Type; 166 } 167 168 static bool is_bool(const Context& context, const Type& type) { 169 if (type.kind() == Type::kVector_Kind) { 170 return is_bool(context, type.componentType()); 171 } 172 return type == *context.fBool_Type; 173 } 174 175 static bool is_out(const Variable& var) { 176 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0; 177 } 178 179 #if SPIRV_DEBUG 180 static String opcode_text(SpvOp_ opCode) { 181 switch (opCode) { 182 case SpvOpNop: 183 return String("Nop"); 184 case SpvOpUndef: 185 return String("Undef"); 186 case SpvOpSourceContinued: 187 return String("SourceContinued"); 188 case SpvOpSource: 189 return String("Source"); 190 case SpvOpSourceExtension: 191 return String("SourceExtension"); 192 case SpvOpName: 193 return String("Name"); 194 case SpvOpMemberName: 195 return String("MemberName"); 196 case SpvOpString: 197 return String("String"); 198 case SpvOpLine: 199 return String("Line"); 200 case SpvOpExtension: 201 return String("Extension"); 202 case SpvOpExtInstImport: 203 return String("ExtInstImport"); 204 case SpvOpExtInst: 205 return String("ExtInst"); 206 case SpvOpMemoryModel: 207 return String("MemoryModel"); 208 case SpvOpEntryPoint: 209 return String("EntryPoint"); 210 case SpvOpExecutionMode: 211 return String("ExecutionMode"); 212 case SpvOpCapability: 213 return String("Capability"); 214 case SpvOpTypeVoid: 215 return String("TypeVoid"); 216 case SpvOpTypeBool: 217 return String("TypeBool"); 218 case SpvOpTypeInt: 219 return String("TypeInt"); 220 case SpvOpTypeFloat: 221 return String("TypeFloat"); 222 case SpvOpTypeVector: 223 return String("TypeVector"); 224 case SpvOpTypeMatrix: 225 return String("TypeMatrix"); 226 case SpvOpTypeImage: 227 return String("TypeImage"); 228 case SpvOpTypeSampler: 229 return String("TypeSampler"); 230 case SpvOpTypeSampledImage: 231 return String("TypeSampledImage"); 232 case SpvOpTypeArray: 233 return String("TypeArray"); 234 case SpvOpTypeRuntimeArray: 235 return String("TypeRuntimeArray"); 236 case SpvOpTypeStruct: 237 return String("TypeStruct"); 238 case SpvOpTypeOpaque: 239 return String("TypeOpaque"); 240 case SpvOpTypePointer: 241 return String("TypePointer"); 242 case SpvOpTypeFunction: 243 return String("TypeFunction"); 244 case SpvOpTypeEvent: 245 return String("TypeEvent"); 246 case SpvOpTypeDeviceEvent: 247 return String("TypeDeviceEvent"); 248 case SpvOpTypeReserveId: 249 return String("TypeReserveId"); 250 case SpvOpTypeQueue: 251 return String("TypeQueue"); 252 case SpvOpTypePipe: 253 return String("TypePipe"); 254 case SpvOpTypeForwardPointer: 255 return String("TypeForwardPointer"); 256 case SpvOpConstantTrue: 257 return String("ConstantTrue"); 258 case SpvOpConstantFalse: 259 return String("ConstantFalse"); 260 case SpvOpConstant: 261 return String("Constant"); 262 case SpvOpConstantComposite: 263 return String("ConstantComposite"); 264 case SpvOpConstantSampler: 265 return String("ConstantSampler"); 266 case SpvOpConstantNull: 267 return String("ConstantNull"); 268 case SpvOpSpecConstantTrue: 269 return String("SpecConstantTrue"); 270 case SpvOpSpecConstantFalse: 271 return String("SpecConstantFalse"); 272 case SpvOpSpecConstant: 273 return String("SpecConstant"); 274 case SpvOpSpecConstantComposite: 275 return String("SpecConstantComposite"); 276 case SpvOpSpecConstantOp: 277 return String("SpecConstantOp"); 278 case SpvOpFunction: 279 return String("Function"); 280 case SpvOpFunctionParameter: 281 return String("FunctionParameter"); 282 case SpvOpFunctionEnd: 283 return String("FunctionEnd"); 284 case SpvOpFunctionCall: 285 return String("FunctionCall"); 286 case SpvOpVariable: 287 return String("Variable"); 288 case SpvOpImageTexelPointer: 289 return String("ImageTexelPointer"); 290 case SpvOpLoad: 291 return String("Load"); 292 case SpvOpStore: 293 return String("Store"); 294 case SpvOpCopyMemory: 295 return String("CopyMemory"); 296 case SpvOpCopyMemorySized: 297 return String("CopyMemorySized"); 298 case SpvOpAccessChain: 299 return String("AccessChain"); 300 case SpvOpInBoundsAccessChain: 301 return String("InBoundsAccessChain"); 302 case SpvOpPtrAccessChain: 303 return String("PtrAccessChain"); 304 case SpvOpArrayLength: 305 return String("ArrayLength"); 306 case SpvOpGenericPtrMemSemantics: 307 return String("GenericPtrMemSemantics"); 308 case SpvOpInBoundsPtrAccessChain: 309 return String("InBoundsPtrAccessChain"); 310 case SpvOpDecorate: 311 return String("Decorate"); 312 case SpvOpMemberDecorate: 313 return String("MemberDecorate"); 314 case SpvOpDecorationGroup: 315 return String("DecorationGroup"); 316 case SpvOpGroupDecorate: 317 return String("GroupDecorate"); 318 case SpvOpGroupMemberDecorate: 319 return String("GroupMemberDecorate"); 320 case SpvOpVectorExtractDynamic: 321 return String("VectorExtractDynamic"); 322 case SpvOpVectorInsertDynamic: 323 return String("VectorInsertDynamic"); 324 case SpvOpVectorShuffle: 325 return String("VectorShuffle"); 326 case SpvOpCompositeConstruct: 327 return String("CompositeConstruct"); 328 case SpvOpCompositeExtract: 329 return String("CompositeExtract"); 330 case SpvOpCompositeInsert: 331 return String("CompositeInsert"); 332 case SpvOpCopyObject: 333 return String("CopyObject"); 334 case SpvOpTranspose: 335 return String("Transpose"); 336 case SpvOpSampledImage: 337 return String("SampledImage"); 338 case SpvOpImageSampleImplicitLod: 339 return String("ImageSampleImplicitLod"); 340 case SpvOpImageSampleExplicitLod: 341 return String("ImageSampleExplicitLod"); 342 case SpvOpImageSampleDrefImplicitLod: 343 return String("ImageSampleDrefImplicitLod"); 344 case SpvOpImageSampleDrefExplicitLod: 345 return String("ImageSampleDrefExplicitLod"); 346 case SpvOpImageSampleProjImplicitLod: 347 return String("ImageSampleProjImplicitLod"); 348 case SpvOpImageSampleProjExplicitLod: 349 return String("ImageSampleProjExplicitLod"); 350 case SpvOpImageSampleProjDrefImplicitLod: 351 return String("ImageSampleProjDrefImplicitLod"); 352 case SpvOpImageSampleProjDrefExplicitLod: 353 return String("ImageSampleProjDrefExplicitLod"); 354 case SpvOpImageFetch: 355 return String("ImageFetch"); 356 case SpvOpImageGather: 357 return String("ImageGather"); 358 case SpvOpImageDrefGather: 359 return String("ImageDrefGather"); 360 case SpvOpImageRead: 361 return String("ImageRead"); 362 case SpvOpImageWrite: 363 return String("ImageWrite"); 364 case SpvOpImage: 365 return String("Image"); 366 case SpvOpImageQueryFormat: 367 return String("ImageQueryFormat"); 368 case SpvOpImageQueryOrder: 369 return String("ImageQueryOrder"); 370 case SpvOpImageQuerySizeLod: 371 return String("ImageQuerySizeLod"); 372 case SpvOpImageQuerySize: 373 return String("ImageQuerySize"); 374 case SpvOpImageQueryLod: 375 return String("ImageQueryLod"); 376 case SpvOpImageQueryLevels: 377 return String("ImageQueryLevels"); 378 case SpvOpImageQuerySamples: 379 return String("ImageQuerySamples"); 380 case SpvOpConvertFToU: 381 return String("ConvertFToU"); 382 case SpvOpConvertFToS: 383 return String("ConvertFToS"); 384 case SpvOpConvertSToF: 385 return String("ConvertSToF"); 386 case SpvOpConvertUToF: 387 return String("ConvertUToF"); 388 case SpvOpUConvert: 389 return String("UConvert"); 390 case SpvOpSConvert: 391 return String("SConvert"); 392 case SpvOpFConvert: 393 return String("FConvert"); 394 case SpvOpQuantizeToF16: 395 return String("QuantizeToF16"); 396 case SpvOpConvertPtrToU: 397 return String("ConvertPtrToU"); 398 case SpvOpSatConvertSToU: 399 return String("SatConvertSToU"); 400 case SpvOpSatConvertUToS: 401 return String("SatConvertUToS"); 402 case SpvOpConvertUToPtr: 403 return String("ConvertUToPtr"); 404 case SpvOpPtrCastToGeneric: 405 return String("PtrCastToGeneric"); 406 case SpvOpGenericCastToPtr: 407 return String("GenericCastToPtr"); 408 case SpvOpGenericCastToPtrExplicit: 409 return String("GenericCastToPtrExplicit"); 410 case SpvOpBitcast: 411 return String("Bitcast"); 412 case SpvOpSNegate: 413 return String("SNegate"); 414 case SpvOpFNegate: 415 return String("FNegate"); 416 case SpvOpIAdd: 417 return String("IAdd"); 418 case SpvOpFAdd: 419 return String("FAdd"); 420 case SpvOpISub: 421 return String("ISub"); 422 case SpvOpFSub: 423 return String("FSub"); 424 case SpvOpIMul: 425 return String("IMul"); 426 case SpvOpFMul: 427 return String("FMul"); 428 case SpvOpUDiv: 429 return String("UDiv"); 430 case SpvOpSDiv: 431 return String("SDiv"); 432 case SpvOpFDiv: 433 return String("FDiv"); 434 case SpvOpUMod: 435 return String("UMod"); 436 case SpvOpSRem: 437 return String("SRem"); 438 case SpvOpSMod: 439 return String("SMod"); 440 case SpvOpFRem: 441 return String("FRem"); 442 case SpvOpFMod: 443 return String("FMod"); 444 case SpvOpVectorTimesScalar: 445 return String("VectorTimesScalar"); 446 case SpvOpMatrixTimesScalar: 447 return String("MatrixTimesScalar"); 448 case SpvOpVectorTimesMatrix: 449 return String("VectorTimesMatrix"); 450 case SpvOpMatrixTimesVector: 451 return String("MatrixTimesVector"); 452 case SpvOpMatrixTimesMatrix: 453 return String("MatrixTimesMatrix"); 454 case SpvOpOuterProduct: 455 return String("OuterProduct"); 456 case SpvOpDot: 457 return String("Dot"); 458 case SpvOpIAddCarry: 459 return String("IAddCarry"); 460 case SpvOpISubBorrow: 461 return String("ISubBorrow"); 462 case SpvOpUMulExtended: 463 return String("UMulExtended"); 464 case SpvOpSMulExtended: 465 return String("SMulExtended"); 466 case SpvOpAny: 467 return String("Any"); 468 case SpvOpAll: 469 return String("All"); 470 case SpvOpIsNan: 471 return String("IsNan"); 472 case SpvOpIsInf: 473 return String("IsInf"); 474 case SpvOpIsFinite: 475 return String("IsFinite"); 476 case SpvOpIsNormal: 477 return String("IsNormal"); 478 case SpvOpSignBitSet: 479 return String("SignBitSet"); 480 case SpvOpLessOrGreater: 481 return String("LessOrGreater"); 482 case SpvOpOrdered: 483 return String("Ordered"); 484 case SpvOpUnordered: 485 return String("Unordered"); 486 case SpvOpLogicalEqual: 487 return String("LogicalEqual"); 488 case SpvOpLogicalNotEqual: 489 return String("LogicalNotEqual"); 490 case SpvOpLogicalOr: 491 return String("LogicalOr"); 492 case SpvOpLogicalAnd: 493 return String("LogicalAnd"); 494 case SpvOpLogicalNot: 495 return String("LogicalNot"); 496 case SpvOpSelect: 497 return String("Select"); 498 case SpvOpIEqual: 499 return String("IEqual"); 500 case SpvOpINotEqual: 501 return String("INotEqual"); 502 case SpvOpUGreaterThan: 503 return String("UGreaterThan"); 504 case SpvOpSGreaterThan: 505 return String("SGreaterThan"); 506 case SpvOpUGreaterThanEqual: 507 return String("UGreaterThanEqual"); 508 case SpvOpSGreaterThanEqual: 509 return String("SGreaterThanEqual"); 510 case SpvOpULessThan: 511 return String("ULessThan"); 512 case SpvOpSLessThan: 513 return String("SLessThan"); 514 case SpvOpULessThanEqual: 515 return String("ULessThanEqual"); 516 case SpvOpSLessThanEqual: 517 return String("SLessThanEqual"); 518 case SpvOpFOrdEqual: 519 return String("FOrdEqual"); 520 case SpvOpFUnordEqual: 521 return String("FUnordEqual"); 522 case SpvOpFOrdNotEqual: 523 return String("FOrdNotEqual"); 524 case SpvOpFUnordNotEqual: 525 return String("FUnordNotEqual"); 526 case SpvOpFOrdLessThan: 527 return String("FOrdLessThan"); 528 case SpvOpFUnordLessThan: 529 return String("FUnordLessThan"); 530 case SpvOpFOrdGreaterThan: 531 return String("FOrdGreaterThan"); 532 case SpvOpFUnordGreaterThan: 533 return String("FUnordGreaterThan"); 534 case SpvOpFOrdLessThanEqual: 535 return String("FOrdLessThanEqual"); 536 case SpvOpFUnordLessThanEqual: 537 return String("FUnordLessThanEqual"); 538 case SpvOpFOrdGreaterThanEqual: 539 return String("FOrdGreaterThanEqual"); 540 case SpvOpFUnordGreaterThanEqual: 541 return String("FUnordGreaterThanEqual"); 542 case SpvOpShiftRightLogical: 543 return String("ShiftRightLogical"); 544 case SpvOpShiftRightArithmetic: 545 return String("ShiftRightArithmetic"); 546 case SpvOpShiftLeftLogical: 547 return String("ShiftLeftLogical"); 548 case SpvOpBitwiseOr: 549 return String("BitwiseOr"); 550 case SpvOpBitwiseXor: 551 return String("BitwiseXor"); 552 case SpvOpBitwiseAnd: 553 return String("BitwiseAnd"); 554 case SpvOpNot: 555 return String("Not"); 556 case SpvOpBitFieldInsert: 557 return String("BitFieldInsert"); 558 case SpvOpBitFieldSExtract: 559 return String("BitFieldSExtract"); 560 case SpvOpBitFieldUExtract: 561 return String("BitFieldUExtract"); 562 case SpvOpBitReverse: 563 return String("BitReverse"); 564 case SpvOpBitCount: 565 return String("BitCount"); 566 case SpvOpDPdx: 567 return String("DPdx"); 568 case SpvOpDPdy: 569 return String("DPdy"); 570 case SpvOpFwidth: 571 return String("Fwidth"); 572 case SpvOpDPdxFine: 573 return String("DPdxFine"); 574 case SpvOpDPdyFine: 575 return String("DPdyFine"); 576 case SpvOpFwidthFine: 577 return String("FwidthFine"); 578 case SpvOpDPdxCoarse: 579 return String("DPdxCoarse"); 580 case SpvOpDPdyCoarse: 581 return String("DPdyCoarse"); 582 case SpvOpFwidthCoarse: 583 return String("FwidthCoarse"); 584 case SpvOpEmitVertex: 585 return String("EmitVertex"); 586 case SpvOpEndPrimitive: 587 return String("EndPrimitive"); 588 case SpvOpEmitStreamVertex: 589 return String("EmitStreamVertex"); 590 case SpvOpEndStreamPrimitive: 591 return String("EndStreamPrimitive"); 592 case SpvOpControlBarrier: 593 return String("ControlBarrier"); 594 case SpvOpMemoryBarrier: 595 return String("MemoryBarrier"); 596 case SpvOpAtomicLoad: 597 return String("AtomicLoad"); 598 case SpvOpAtomicStore: 599 return String("AtomicStore"); 600 case SpvOpAtomicExchange: 601 return String("AtomicExchange"); 602 case SpvOpAtomicCompareExchange: 603 return String("AtomicCompareExchange"); 604 case SpvOpAtomicCompareExchangeWeak: 605 return String("AtomicCompareExchangeWeak"); 606 case SpvOpAtomicIIncrement: 607 return String("AtomicIIncrement"); 608 case SpvOpAtomicIDecrement: 609 return String("AtomicIDecrement"); 610 case SpvOpAtomicIAdd: 611 return String("AtomicIAdd"); 612 case SpvOpAtomicISub: 613 return String("AtomicISub"); 614 case SpvOpAtomicSMin: 615 return String("AtomicSMin"); 616 case SpvOpAtomicUMin: 617 return String("AtomicUMin"); 618 case SpvOpAtomicSMax: 619 return String("AtomicSMax"); 620 case SpvOpAtomicUMax: 621 return String("AtomicUMax"); 622 case SpvOpAtomicAnd: 623 return String("AtomicAnd"); 624 case SpvOpAtomicOr: 625 return String("AtomicOr"); 626 case SpvOpAtomicXor: 627 return String("AtomicXor"); 628 case SpvOpPhi: 629 return String("Phi"); 630 case SpvOpLoopMerge: 631 return String("LoopMerge"); 632 case SpvOpSelectionMerge: 633 return String("SelectionMerge"); 634 case SpvOpLabel: 635 return String("Label"); 636 case SpvOpBranch: 637 return String("Branch"); 638 case SpvOpBranchConditional: 639 return String("BranchConditional"); 640 case SpvOpSwitch: 641 return String("Switch"); 642 case SpvOpKill: 643 return String("Kill"); 644 case SpvOpReturn: 645 return String("Return"); 646 case SpvOpReturnValue: 647 return String("ReturnValue"); 648 case SpvOpUnreachable: 649 return String("Unreachable"); 650 case SpvOpLifetimeStart: 651 return String("LifetimeStart"); 652 case SpvOpLifetimeStop: 653 return String("LifetimeStop"); 654 case SpvOpGroupAsyncCopy: 655 return String("GroupAsyncCopy"); 656 case SpvOpGroupWaitEvents: 657 return String("GroupWaitEvents"); 658 case SpvOpGroupAll: 659 return String("GroupAll"); 660 case SpvOpGroupAny: 661 return String("GroupAny"); 662 case SpvOpGroupBroadcast: 663 return String("GroupBroadcast"); 664 case SpvOpGroupIAdd: 665 return String("GroupIAdd"); 666 case SpvOpGroupFAdd: 667 return String("GroupFAdd"); 668 case SpvOpGroupFMin: 669 return String("GroupFMin"); 670 case SpvOpGroupUMin: 671 return String("GroupUMin"); 672 case SpvOpGroupSMin: 673 return String("GroupSMin"); 674 case SpvOpGroupFMax: 675 return String("GroupFMax"); 676 case SpvOpGroupUMax: 677 return String("GroupUMax"); 678 case SpvOpGroupSMax: 679 return String("GroupSMax"); 680 case SpvOpReadPipe: 681 return String("ReadPipe"); 682 case SpvOpWritePipe: 683 return String("WritePipe"); 684 case SpvOpReservedReadPipe: 685 return String("ReservedReadPipe"); 686 case SpvOpReservedWritePipe: 687 return String("ReservedWritePipe"); 688 case SpvOpReserveReadPipePackets: 689 return String("ReserveReadPipePackets"); 690 case SpvOpReserveWritePipePackets: 691 return String("ReserveWritePipePackets"); 692 case SpvOpCommitReadPipe: 693 return String("CommitReadPipe"); 694 case SpvOpCommitWritePipe: 695 return String("CommitWritePipe"); 696 case SpvOpIsValidReserveId: 697 return String("IsValidReserveId"); 698 case SpvOpGetNumPipePackets: 699 return String("GetNumPipePackets"); 700 case SpvOpGetMaxPipePackets: 701 return String("GetMaxPipePackets"); 702 case SpvOpGroupReserveReadPipePackets: 703 return String("GroupReserveReadPipePackets"); 704 case SpvOpGroupReserveWritePipePackets: 705 return String("GroupReserveWritePipePackets"); 706 case SpvOpGroupCommitReadPipe: 707 return String("GroupCommitReadPipe"); 708 case SpvOpGroupCommitWritePipe: 709 return String("GroupCommitWritePipe"); 710 case SpvOpEnqueueMarker: 711 return String("EnqueueMarker"); 712 case SpvOpEnqueueKernel: 713 return String("EnqueueKernel"); 714 case SpvOpGetKernelNDrangeSubGroupCount: 715 return String("GetKernelNDrangeSubGroupCount"); 716 case SpvOpGetKernelNDrangeMaxSubGroupSize: 717 return String("GetKernelNDrangeMaxSubGroupSize"); 718 case SpvOpGetKernelWorkGroupSize: 719 return String("GetKernelWorkGroupSize"); 720 case SpvOpGetKernelPreferredWorkGroupSizeMultiple: 721 return String("GetKernelPreferredWorkGroupSizeMultiple"); 722 case SpvOpRetainEvent: 723 return String("RetainEvent"); 724 case SpvOpReleaseEvent: 725 return String("ReleaseEvent"); 726 case SpvOpCreateUserEvent: 727 return String("CreateUserEvent"); 728 case SpvOpIsValidEvent: 729 return String("IsValidEvent"); 730 case SpvOpSetUserEventStatus: 731 return String("SetUserEventStatus"); 732 case SpvOpCaptureEventProfilingInfo: 733 return String("CaptureEventProfilingInfo"); 734 case SpvOpGetDefaultQueue: 735 return String("GetDefaultQueue"); 736 case SpvOpBuildNDRange: 737 return String("BuildNDRange"); 738 case SpvOpImageSparseSampleImplicitLod: 739 return String("ImageSparseSampleImplicitLod"); 740 case SpvOpImageSparseSampleExplicitLod: 741 return String("ImageSparseSampleExplicitLod"); 742 case SpvOpImageSparseSampleDrefImplicitLod: 743 return String("ImageSparseSampleDrefImplicitLod"); 744 case SpvOpImageSparseSampleDrefExplicitLod: 745 return String("ImageSparseSampleDrefExplicitLod"); 746 case SpvOpImageSparseSampleProjImplicitLod: 747 return String("ImageSparseSampleProjImplicitLod"); 748 case SpvOpImageSparseSampleProjExplicitLod: 749 return String("ImageSparseSampleProjExplicitLod"); 750 case SpvOpImageSparseSampleProjDrefImplicitLod: 751 return String("ImageSparseSampleProjDrefImplicitLod"); 752 case SpvOpImageSparseSampleProjDrefExplicitLod: 753 return String("ImageSparseSampleProjDrefExplicitLod"); 754 case SpvOpImageSparseFetch: 755 return String("ImageSparseFetch"); 756 case SpvOpImageSparseGather: 757 return String("ImageSparseGather"); 758 case SpvOpImageSparseDrefGather: 759 return String("ImageSparseDrefGather"); 760 case SpvOpImageSparseTexelsResident: 761 return String("ImageSparseTexelsResident"); 762 case SpvOpNoLine: 763 return String("NoLine"); 764 case SpvOpAtomicFlagTestAndSet: 765 return String("AtomicFlagTestAndSet"); 766 case SpvOpAtomicFlagClear: 767 return String("AtomicFlagClear"); 768 case SpvOpImageSparseRead: 769 return String("ImageSparseRead"); 770 default: 771 ABORT("unsupported SPIR-V op"); 772 } 773 } 774 #endif 775 776 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) { 777 ASSERT(opCode != SpvOpUndef); 778 switch (opCode) { 779 case SpvOpReturn: // fall through 780 case SpvOpReturnValue: // fall through 781 case SpvOpKill: // fall through 782 case SpvOpBranch: // fall through 783 case SpvOpBranchConditional: 784 ASSERT(fCurrentBlock); 785 fCurrentBlock = 0; 786 break; 787 case SpvOpConstant: // fall through 788 case SpvOpConstantTrue: // fall through 789 case SpvOpConstantFalse: // fall through 790 case SpvOpConstantComposite: // fall through 791 case SpvOpTypeVoid: // fall through 792 case SpvOpTypeInt: // fall through 793 case SpvOpTypeFloat: // fall through 794 case SpvOpTypeBool: // fall through 795 case SpvOpTypeVector: // fall through 796 case SpvOpTypeMatrix: // fall through 797 case SpvOpTypeArray: // fall through 798 case SpvOpTypePointer: // fall through 799 case SpvOpTypeFunction: // fall through 800 case SpvOpTypeRuntimeArray: // fall through 801 case SpvOpTypeStruct: // fall through 802 case SpvOpTypeImage: // fall through 803 case SpvOpTypeSampledImage: // fall through 804 case SpvOpVariable: // fall through 805 case SpvOpFunction: // fall through 806 case SpvOpFunctionParameter: // fall through 807 case SpvOpFunctionEnd: // fall through 808 case SpvOpExecutionMode: // fall through 809 case SpvOpMemoryModel: // fall through 810 case SpvOpCapability: // fall through 811 case SpvOpExtInstImport: // fall through 812 case SpvOpEntryPoint: // fall through 813 case SpvOpSource: // fall through 814 case SpvOpSourceExtension: // fall through 815 case SpvOpName: // fall through 816 case SpvOpMemberName: // fall through 817 case SpvOpDecorate: // fall through 818 case SpvOpMemberDecorate: 819 break; 820 default: 821 ASSERT(fCurrentBlock); 822 } 823 #if SPIRV_DEBUG 824 out << std::endl << opcode_text(opCode) << " "; 825 #else 826 this->writeWord((length << 16) | opCode, out); 827 #endif 828 } 829 830 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) { 831 fCurrentBlock = label; 832 this->writeInstruction(SpvOpLabel, label, out); 833 } 834 835 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) { 836 this->writeOpCode(opCode, 1, out); 837 } 838 839 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) { 840 this->writeOpCode(opCode, 2, out); 841 this->writeWord(word1, out); 842 } 843 844 void SPIRVCodeGenerator::writeString(const char* string, OutputStream& out) { 845 size_t length = strlen(string); 846 out.write(string, length); 847 switch (length % 4) { 848 case 1: 849 out.write8(0); 850 // fall through 851 case 2: 852 out.write8(0); 853 // fall through 854 case 3: 855 out.write8(0); 856 break; 857 default: 858 this->writeWord(0, out); 859 } 860 } 861 862 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, OutputStream& out) { 863 int32_t length = (int32_t) strlen(string); 864 this->writeOpCode(opCode, 1 + (length + 4) / 4, out); 865 this->writeString(string, out); 866 } 867 868 869 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string, 870 OutputStream& out) { 871 int32_t length = (int32_t) strlen(string); 872 this->writeOpCode(opCode, 2 + (length + 4) / 4, out); 873 this->writeWord(word1, out); 874 this->writeString(string, out); 875 } 876 877 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 878 const char* string, OutputStream& out) { 879 int32_t length = (int32_t) strlen(string); 880 this->writeOpCode(opCode, 3 + (length + 4) / 4, out); 881 this->writeWord(word1, out); 882 this->writeWord(word2, out); 883 this->writeString(string, out); 884 } 885 886 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 887 OutputStream& out) { 888 this->writeOpCode(opCode, 3, out); 889 this->writeWord(word1, out); 890 this->writeWord(word2, out); 891 } 892 893 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 894 int32_t word3, OutputStream& out) { 895 this->writeOpCode(opCode, 4, out); 896 this->writeWord(word1, out); 897 this->writeWord(word2, out); 898 this->writeWord(word3, out); 899 } 900 901 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 902 int32_t word3, int32_t word4, OutputStream& out) { 903 this->writeOpCode(opCode, 5, out); 904 this->writeWord(word1, out); 905 this->writeWord(word2, out); 906 this->writeWord(word3, out); 907 this->writeWord(word4, out); 908 } 909 910 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 911 int32_t word3, int32_t word4, int32_t word5, 912 OutputStream& out) { 913 this->writeOpCode(opCode, 6, out); 914 this->writeWord(word1, out); 915 this->writeWord(word2, out); 916 this->writeWord(word3, out); 917 this->writeWord(word4, out); 918 this->writeWord(word5, out); 919 } 920 921 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 922 int32_t word3, int32_t word4, int32_t word5, 923 int32_t word6, OutputStream& out) { 924 this->writeOpCode(opCode, 7, out); 925 this->writeWord(word1, out); 926 this->writeWord(word2, out); 927 this->writeWord(word3, out); 928 this->writeWord(word4, out); 929 this->writeWord(word5, out); 930 this->writeWord(word6, out); 931 } 932 933 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 934 int32_t word3, int32_t word4, int32_t word5, 935 int32_t word6, int32_t word7, OutputStream& out) { 936 this->writeOpCode(opCode, 8, out); 937 this->writeWord(word1, out); 938 this->writeWord(word2, out); 939 this->writeWord(word3, out); 940 this->writeWord(word4, out); 941 this->writeWord(word5, out); 942 this->writeWord(word6, out); 943 this->writeWord(word7, out); 944 } 945 946 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 947 int32_t word3, int32_t word4, int32_t word5, 948 int32_t word6, int32_t word7, int32_t word8, 949 OutputStream& out) { 950 this->writeOpCode(opCode, 9, out); 951 this->writeWord(word1, out); 952 this->writeWord(word2, out); 953 this->writeWord(word3, out); 954 this->writeWord(word4, out); 955 this->writeWord(word5, out); 956 this->writeWord(word6, out); 957 this->writeWord(word7, out); 958 this->writeWord(word8, out); 959 } 960 961 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) { 962 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) { 963 if (fCapabilities & bit) { 964 this->writeInstruction(SpvOpCapability, (SpvId) i, out); 965 } 966 } 967 } 968 969 SpvId SPIRVCodeGenerator::nextId() { 970 return fIdCount++; 971 } 972 973 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout, 974 SpvId resultId) { 975 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer); 976 // go ahead and write all of the field types, so we don't inadvertently write them while we're 977 // in the middle of writing the struct instruction 978 std::vector<SpvId> types; 979 for (const auto& f : type.fields()) { 980 types.push_back(this->getType(*f.fType, memoryLayout)); 981 } 982 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); 983 this->writeWord(resultId, fConstantBuffer); 984 for (SpvId id : types) { 985 this->writeWord(id, fConstantBuffer); 986 } 987 size_t offset = 0; 988 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { 989 size_t size = memoryLayout.size(*type.fields()[i].fType); 990 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType); 991 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout; 992 if (fieldLayout.fOffset >= 0) { 993 if (fieldLayout.fOffset < (int) offset) { 994 fErrors.error(type.fPosition, 995 "offset of field '" + type.fields()[i].fName + "' must be at " 996 "least " + to_string((int) offset)); 997 } 998 if (fieldLayout.fOffset % alignment) { 999 fErrors.error(type.fPosition, 1000 "offset of field '" + type.fields()[i].fName + "' must be a multiple" 1001 " of " + to_string((int) alignment)); 1002 } 1003 offset = fieldLayout.fOffset; 1004 } else { 1005 size_t mod = offset % alignment; 1006 if (mod) { 1007 offset += alignment - mod; 1008 } 1009 } 1010 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(), 1011 fNameBuffer); 1012 this->writeLayout(fieldLayout, resultId, i); 1013 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) { 1014 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, 1015 (SpvId) offset, fDecorationBuffer); 1016 } 1017 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) { 1018 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, 1019 fDecorationBuffer); 1020 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, 1021 (SpvId) memoryLayout.stride(*type.fields()[i].fType), 1022 fDecorationBuffer); 1023 } 1024 offset += size; 1025 Type::Kind kind = type.fields()[i].fType->kind(); 1026 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) { 1027 offset += alignment - offset % alignment; 1028 } 1029 } 1030 } 1031 1032 SpvId SPIRVCodeGenerator::getType(const Type& type) { 1033 return this->getType(type, fDefaultLayout); 1034 } 1035 1036 SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) { 1037 String key = type.name() + to_string((int) layout.fStd); 1038 auto entry = fTypeMap.find(key); 1039 if (entry == fTypeMap.end()) { 1040 SpvId result = this->nextId(); 1041 switch (type.kind()) { 1042 case Type::kScalar_Kind: 1043 if (type == *fContext.fBool_Type) { 1044 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer); 1045 } else if (type == *fContext.fInt_Type) { 1046 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer); 1047 } else if (type == *fContext.fUInt_Type) { 1048 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer); 1049 } else if (type == *fContext.fFloat_Type) { 1050 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer); 1051 } else if (type == *fContext.fDouble_Type) { 1052 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer); 1053 } else { 1054 ASSERT(false); 1055 } 1056 break; 1057 case Type::kVector_Kind: 1058 this->writeInstruction(SpvOpTypeVector, result, 1059 this->getType(type.componentType(), layout), 1060 type.columns(), fConstantBuffer); 1061 break; 1062 case Type::kMatrix_Kind: 1063 this->writeInstruction(SpvOpTypeMatrix, result, 1064 this->getType(index_type(fContext, type), layout), 1065 type.columns(), fConstantBuffer); 1066 break; 1067 case Type::kStruct_Kind: 1068 this->writeStruct(type, layout, result); 1069 break; 1070 case Type::kArray_Kind: { 1071 if (type.columns() > 0) { 1072 IntLiteral count(fContext, Position(), type.columns()); 1073 this->writeInstruction(SpvOpTypeArray, result, 1074 this->getType(type.componentType(), layout), 1075 this->writeIntLiteral(count), fConstantBuffer); 1076 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 1077 (int32_t) layout.stride(type), 1078 fDecorationBuffer); 1079 } else { 1080 this->writeInstruction(SpvOpTypeRuntimeArray, result, 1081 this->getType(type.componentType(), layout), 1082 fConstantBuffer); 1083 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 1084 (int32_t) layout.stride(type), 1085 fDecorationBuffer); 1086 } 1087 break; 1088 } 1089 case Type::kSampler_Kind: { 1090 SpvId image = result; 1091 if (SpvDimSubpassData != type.dimensions()) { 1092 image = this->nextId(); 1093 } 1094 if (SpvDimBuffer == type.dimensions()) { 1095 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer); 1096 } 1097 this->writeInstruction(SpvOpTypeImage, image, 1098 this->getType(*fContext.fFloat_Type, layout), 1099 type.dimensions(), type.isDepth(), type.isArrayed(), 1100 type.isMultisampled(), type.isSampled() ? 1 : 2, 1101 SpvImageFormatUnknown, fConstantBuffer); 1102 fImageTypeMap[key] = image; 1103 if (SpvDimSubpassData != type.dimensions()) { 1104 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer); 1105 } 1106 break; 1107 } 1108 default: 1109 if (type == *fContext.fVoid_Type) { 1110 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer); 1111 } else { 1112 ABORT("invalid type: %s", type.description().c_str()); 1113 } 1114 } 1115 fTypeMap[key] = result; 1116 return result; 1117 } 1118 return entry->second; 1119 } 1120 1121 SpvId SPIRVCodeGenerator::getImageType(const Type& type) { 1122 ASSERT(type.kind() == Type::kSampler_Kind); 1123 this->getType(type); 1124 String key = type.name() + to_string((int) fDefaultLayout.fStd); 1125 ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end()); 1126 return fImageTypeMap[key]; 1127 } 1128 1129 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { 1130 String key = function.fReturnType.description() + "("; 1131 String separator; 1132 for (size_t i = 0; i < function.fParameters.size(); i++) { 1133 key += separator; 1134 separator = ", "; 1135 key += function.fParameters[i]->fType.description(); 1136 } 1137 key += ")"; 1138 auto entry = fTypeMap.find(key); 1139 if (entry == fTypeMap.end()) { 1140 SpvId result = this->nextId(); 1141 int32_t length = 3 + (int32_t) function.fParameters.size(); 1142 SpvId returnType = this->getType(function.fReturnType); 1143 std::vector<SpvId> parameterTypes; 1144 for (size_t i = 0; i < function.fParameters.size(); i++) { 1145 // glslang seems to treat all function arguments as pointers whether they need to be or 1146 // not. I was initially puzzled by this until I ran bizarre failures with certain 1147 // patterns of function calls and control constructs, as exemplified by this minimal 1148 // failure case: 1149 // 1150 // void sphere(float x) { 1151 // } 1152 // 1153 // void map() { 1154 // sphere(1.0); 1155 // } 1156 // 1157 // void main() { 1158 // for (int i = 0; i < 1; i++) { 1159 // map(); 1160 // } 1161 // } 1162 // 1163 // As of this writing, compiling this in the "obvious" way (with sphere taking a float) 1164 // crashes. Making it take a float* and storing the argument in a temporary variable, 1165 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of 1166 // the spec makes this make sense. 1167 // if (is_out(function->fParameters[i])) { 1168 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType, 1169 SpvStorageClassFunction)); 1170 // } else { 1171 // parameterTypes.push_back(this->getType(function.fParameters[i]->fType)); 1172 // } 1173 } 1174 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); 1175 this->writeWord(result, fConstantBuffer); 1176 this->writeWord(returnType, fConstantBuffer); 1177 for (SpvId id : parameterTypes) { 1178 this->writeWord(id, fConstantBuffer); 1179 } 1180 fTypeMap[key] = result; 1181 return result; 1182 } 1183 return entry->second; 1184 } 1185 1186 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) { 1187 return this->getPointerType(type, fDefaultLayout, storageClass); 1188 } 1189 1190 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout, 1191 SpvStorageClass_ storageClass) { 1192 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass); 1193 auto entry = fTypeMap.find(key); 1194 if (entry == fTypeMap.end()) { 1195 SpvId result = this->nextId(); 1196 this->writeInstruction(SpvOpTypePointer, result, storageClass, 1197 this->getType(type), fConstantBuffer); 1198 fTypeMap[key] = result; 1199 return result; 1200 } 1201 return entry->second; 1202 } 1203 1204 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) { 1205 switch (expr.fKind) { 1206 case Expression::kBinary_Kind: 1207 return this->writeBinaryExpression((BinaryExpression&) expr, out); 1208 case Expression::kBoolLiteral_Kind: 1209 return this->writeBoolLiteral((BoolLiteral&) expr); 1210 case Expression::kConstructor_Kind: 1211 return this->writeConstructor((Constructor&) expr, out); 1212 case Expression::kIntLiteral_Kind: 1213 return this->writeIntLiteral((IntLiteral&) expr); 1214 case Expression::kFieldAccess_Kind: 1215 return this->writeFieldAccess(((FieldAccess&) expr), out); 1216 case Expression::kFloatLiteral_Kind: 1217 return this->writeFloatLiteral(((FloatLiteral&) expr)); 1218 case Expression::kFunctionCall_Kind: 1219 return this->writeFunctionCall((FunctionCall&) expr, out); 1220 case Expression::kPrefix_Kind: 1221 return this->writePrefixExpression((PrefixExpression&) expr, out); 1222 case Expression::kPostfix_Kind: 1223 return this->writePostfixExpression((PostfixExpression&) expr, out); 1224 case Expression::kSwizzle_Kind: 1225 return this->writeSwizzle((Swizzle&) expr, out); 1226 case Expression::kVariableReference_Kind: 1227 return this->writeVariableReference((VariableReference&) expr, out); 1228 case Expression::kTernary_Kind: 1229 return this->writeTernaryExpression((TernaryExpression&) expr, out); 1230 case Expression::kIndex_Kind: 1231 return this->writeIndexExpression((IndexExpression&) expr, out); 1232 default: 1233 ABORT("unsupported expression: %s", expr.description().c_str()); 1234 } 1235 return -1; 1236 } 1237 1238 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) { 1239 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName); 1240 ASSERT(intrinsic != fIntrinsicMap.end()); 1241 const Type& type = c.fArguments[0]->fType; 1242 int32_t intrinsicId; 1243 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) { 1244 intrinsicId = std::get<1>(intrinsic->second); 1245 } else if (is_signed(fContext, type)) { 1246 intrinsicId = std::get<2>(intrinsic->second); 1247 } else if (is_unsigned(fContext, type)) { 1248 intrinsicId = std::get<3>(intrinsic->second); 1249 } else if (is_bool(fContext, type)) { 1250 intrinsicId = std::get<4>(intrinsic->second); 1251 } else { 1252 ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(), 1253 type.description().c_str()); 1254 } 1255 switch (std::get<0>(intrinsic->second)) { 1256 case kGLSL_STD_450_IntrinsicKind: { 1257 SpvId result = this->nextId(); 1258 std::vector<SpvId> arguments; 1259 for (size_t i = 0; i < c.fArguments.size(); i++) { 1260 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1261 } 1262 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 1263 this->writeWord(this->getType(c.fType), out); 1264 this->writeWord(result, out); 1265 this->writeWord(fGLSLExtendedInstructions, out); 1266 this->writeWord(intrinsicId, out); 1267 for (SpvId id : arguments) { 1268 this->writeWord(id, out); 1269 } 1270 return result; 1271 } 1272 case kSPIRV_IntrinsicKind: { 1273 SpvId result = this->nextId(); 1274 std::vector<SpvId> arguments; 1275 for (size_t i = 0; i < c.fArguments.size(); i++) { 1276 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1277 } 1278 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); 1279 this->writeWord(this->getType(c.fType), out); 1280 this->writeWord(result, out); 1281 for (SpvId id : arguments) { 1282 this->writeWord(id, out); 1283 } 1284 return result; 1285 } 1286 case kSpecial_IntrinsicKind: 1287 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out); 1288 default: 1289 ABORT("unsupported intrinsic kind"); 1290 } 1291 } 1292 1293 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, 1294 OutputStream& out) { 1295 SpvId result = this->nextId(); 1296 switch (kind) { 1297 case kAtan_SpecialIntrinsic: { 1298 std::vector<SpvId> arguments; 1299 for (size_t i = 0; i < c.fArguments.size(); i++) { 1300 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1301 } 1302 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 1303 this->writeWord(this->getType(c.fType), out); 1304 this->writeWord(result, out); 1305 this->writeWord(fGLSLExtendedInstructions, out); 1306 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); 1307 for (SpvId id : arguments) { 1308 this->writeWord(id, out); 1309 } 1310 break; 1311 } 1312 case kSubpassLoad_SpecialIntrinsic: { 1313 SpvId img = this->writeExpression(*c.fArguments[0], out); 1314 std::vector<std::unique_ptr<Expression>> args; 1315 args.emplace_back(new FloatLiteral(fContext, Position(), 0.0)); 1316 args.emplace_back(new FloatLiteral(fContext, Position(), 0.0)); 1317 Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args)); 1318 SpvId coords = this->writeConstantVector(ctor); 1319 if (1 == c.fArguments.size()) { 1320 this->writeInstruction(SpvOpImageRead, 1321 this->getType(c.fType), 1322 result, 1323 img, 1324 coords, 1325 out); 1326 } else { 1327 ASSERT(2 == c.fArguments.size()); 1328 SpvId sample = this->writeExpression(*c.fArguments[1], out); 1329 this->writeInstruction(SpvOpImageRead, 1330 this->getType(c.fType), 1331 result, 1332 img, 1333 coords, 1334 SpvImageOperandsSampleMask, 1335 sample, 1336 out); 1337 } 1338 break; 1339 } 1340 case kTexelFetch_SpecialIntrinsic: { 1341 ASSERT(c.fArguments.size() == 2); 1342 SpvId image = this->nextId(); 1343 this->writeInstruction(SpvOpImage, 1344 this->getImageType(c.fArguments[0]->fType), 1345 image, 1346 this->writeExpression(*c.fArguments[0], out), 1347 out); 1348 this->writeInstruction(SpvOpImageFetch, 1349 this->getType(c.fType), 1350 result, 1351 image, 1352 this->writeExpression(*c.fArguments[1], out), 1353 out); 1354 break; 1355 } 1356 case kTexture_SpecialIntrinsic: { 1357 SpvOp_ op = SpvOpImageSampleImplicitLod; 1358 switch (c.fArguments[0]->fType.dimensions()) { 1359 case SpvDim1D: 1360 if (c.fArguments[1]->fType == *fContext.fVec2_Type) { 1361 op = SpvOpImageSampleProjImplicitLod; 1362 } else { 1363 ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type); 1364 } 1365 break; 1366 case SpvDim2D: 1367 if (c.fArguments[1]->fType == *fContext.fVec3_Type) { 1368 op = SpvOpImageSampleProjImplicitLod; 1369 } else { 1370 ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type); 1371 } 1372 break; 1373 case SpvDim3D: 1374 if (c.fArguments[1]->fType == *fContext.fVec4_Type) { 1375 op = SpvOpImageSampleProjImplicitLod; 1376 } else { 1377 ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type); 1378 } 1379 break; 1380 case SpvDimCube: // fall through 1381 case SpvDimRect: // fall through 1382 case SpvDimBuffer: // fall through 1383 case SpvDimSubpassData: 1384 break; 1385 } 1386 SpvId type = this->getType(c.fType); 1387 SpvId sampler = this->writeExpression(*c.fArguments[0], out); 1388 SpvId uv = this->writeExpression(*c.fArguments[1], out); 1389 if (c.fArguments.size() == 3) { 1390 this->writeInstruction(op, type, result, sampler, uv, 1391 SpvImageOperandsBiasMask, 1392 this->writeExpression(*c.fArguments[2], out), 1393 out); 1394 } else { 1395 ASSERT(c.fArguments.size() == 2); 1396 this->writeInstruction(op, type, result, sampler, uv, 1397 out); 1398 } 1399 break; 1400 } 1401 } 1402 return result; 1403 } 1404 1405 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) { 1406 const auto& entry = fFunctionMap.find(&c.fFunction); 1407 if (entry == fFunctionMap.end()) { 1408 return this->writeIntrinsicCall(c, out); 1409 } 1410 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete 1411 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues; 1412 std::vector<SpvId> arguments; 1413 for (size_t i = 0; i < c.fArguments.size(); i++) { 1414 // id of temporary variable that we will use to hold this argument, or 0 if it is being 1415 // passed directly 1416 SpvId tmpVar; 1417 // if we need a temporary var to store this argument, this is the value to store in the var 1418 SpvId tmpValueId; 1419 if (is_out(*c.fFunction.fParameters[i])) { 1420 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out); 1421 SpvId ptr = lv->getPointer(); 1422 if (ptr) { 1423 arguments.push_back(ptr); 1424 continue; 1425 } else { 1426 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to 1427 // copy it into a temp, call the function, read the value out of the temp, and then 1428 // update the lvalue. 1429 tmpValueId = lv->load(out); 1430 tmpVar = this->nextId(); 1431 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType), 1432 std::move(lv))); 1433 } 1434 } else { 1435 // see getFunctionType for an explanation of why we're always using pointer parameters 1436 tmpValueId = this->writeExpression(*c.fArguments[i], out); 1437 tmpVar = this->nextId(); 1438 } 1439 this->writeInstruction(SpvOpVariable, 1440 this->getPointerType(c.fArguments[i]->fType, 1441 SpvStorageClassFunction), 1442 tmpVar, 1443 SpvStorageClassFunction, 1444 fVariableBuffer); 1445 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); 1446 arguments.push_back(tmpVar); 1447 } 1448 SpvId result = this->nextId(); 1449 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out); 1450 this->writeWord(this->getType(c.fType), out); 1451 this->writeWord(result, out); 1452 this->writeWord(entry->second, out); 1453 for (SpvId id : arguments) { 1454 this->writeWord(id, out); 1455 } 1456 // now that the call is complete, we may need to update some lvalues with the new values of out 1457 // arguments 1458 for (const auto& tuple : lvalues) { 1459 SpvId load = this->nextId(); 1460 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out); 1461 std::get<2>(tuple)->store(load, out); 1462 } 1463 return result; 1464 } 1465 1466 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) { 1467 ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant()); 1468 SpvId result = this->nextId(); 1469 std::vector<SpvId> arguments; 1470 for (size_t i = 0; i < c.fArguments.size(); i++) { 1471 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer)); 1472 } 1473 SpvId type = this->getType(c.fType); 1474 if (c.fArguments.size() == 1) { 1475 // with a single argument, a vector will have all of its entries equal to the argument 1476 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer); 1477 this->writeWord(type, fConstantBuffer); 1478 this->writeWord(result, fConstantBuffer); 1479 for (int i = 0; i < c.fType.columns(); i++) { 1480 this->writeWord(arguments[0], fConstantBuffer); 1481 } 1482 } else { 1483 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(), 1484 fConstantBuffer); 1485 this->writeWord(type, fConstantBuffer); 1486 this->writeWord(result, fConstantBuffer); 1487 for (SpvId id : arguments) { 1488 this->writeWord(id, fConstantBuffer); 1489 } 1490 } 1491 return result; 1492 } 1493 1494 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) { 1495 ASSERT(c.fType == *fContext.fFloat_Type); 1496 ASSERT(c.fArguments.size() == 1); 1497 ASSERT(c.fArguments[0]->fType.isNumber()); 1498 SpvId result = this->nextId(); 1499 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1500 if (c.fArguments[0]->fType == *fContext.fInt_Type) { 1501 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, 1502 out); 1503 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) { 1504 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, 1505 out); 1506 } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) { 1507 return parameter; 1508 } 1509 return result; 1510 } 1511 1512 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) { 1513 ASSERT(c.fType == *fContext.fInt_Type); 1514 ASSERT(c.fArguments.size() == 1); 1515 ASSERT(c.fArguments[0]->fType.isNumber()); 1516 SpvId result = this->nextId(); 1517 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1518 if (c.fArguments[0]->fType == *fContext.fFloat_Type) { 1519 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, 1520 out); 1521 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) { 1522 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 1523 out); 1524 } else if (c.fArguments[0]->fType == *fContext.fInt_Type) { 1525 return parameter; 1526 } 1527 return result; 1528 } 1529 1530 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) { 1531 ASSERT(c.fType == *fContext.fUInt_Type); 1532 ASSERT(c.fArguments.size() == 1); 1533 ASSERT(c.fArguments[0]->fType.isNumber()); 1534 SpvId result = this->nextId(); 1535 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1536 if (c.fArguments[0]->fType == *fContext.fFloat_Type) { 1537 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter, 1538 out); 1539 } else if (c.fArguments[0]->fType == *fContext.fInt_Type) { 1540 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 1541 out); 1542 } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) { 1543 return parameter; 1544 } 1545 return result; 1546 } 1547 1548 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, 1549 OutputStream& out) { 1550 FloatLiteral zero(fContext, Position(), 0); 1551 SpvId zeroId = this->writeFloatLiteral(zero); 1552 std::vector<SpvId> columnIds; 1553 for (int column = 0; column < type.columns(); column++) { 1554 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(), 1555 out); 1556 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)), 1557 out); 1558 SpvId columnId = this->nextId(); 1559 this->writeWord(columnId, out); 1560 columnIds.push_back(columnId); 1561 for (int row = 0; row < type.columns(); row++) { 1562 this->writeWord(row == column ? diagonal : zeroId, out); 1563 } 1564 } 1565 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), 1566 out); 1567 this->writeWord(this->getType(type), out); 1568 this->writeWord(id, out); 1569 for (SpvId id : columnIds) { 1570 this->writeWord(id, out); 1571 } 1572 } 1573 1574 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, 1575 const Type& dstType, OutputStream& out) { 1576 ABORT("unimplemented"); 1577 } 1578 1579 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) { 1580 ASSERT(c.fType.kind() == Type::kMatrix_Kind); 1581 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1582 // an instruction 1583 std::vector<SpvId> arguments; 1584 for (size_t i = 0; i < c.fArguments.size(); i++) { 1585 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1586 } 1587 SpvId result = this->nextId(); 1588 int rows = c.fType.rows(); 1589 int columns = c.fType.columns(); 1590 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1591 this->writeUniformScaleMatrix(result, arguments[0], c.fType, out); 1592 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) { 1593 this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out); 1594 } else { 1595 std::vector<SpvId> columnIds; 1596 int currentCount = 0; 1597 for (size_t i = 0; i < arguments.size(); i++) { 1598 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) { 1599 ASSERT(currentCount == 0); 1600 columnIds.push_back(arguments[i]); 1601 currentCount = 0; 1602 } else { 1603 ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind); 1604 if (currentCount == 0) { 1605 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out); 1606 this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1607 1)), 1608 out); 1609 SpvId id = this->nextId(); 1610 this->writeWord(id, out); 1611 columnIds.push_back(id); 1612 } 1613 this->writeWord(arguments[i], out); 1614 currentCount = (currentCount + 1) % rows; 1615 } 1616 } 1617 ASSERT(columnIds.size() == (size_t) columns); 1618 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); 1619 this->writeWord(this->getType(c.fType), out); 1620 this->writeWord(result, out); 1621 for (SpvId id : columnIds) { 1622 this->writeWord(id, out); 1623 } 1624 } 1625 return result; 1626 } 1627 1628 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) { 1629 ASSERT(c.fType.kind() == Type::kVector_Kind); 1630 if (c.isConstant()) { 1631 return this->writeConstantVector(c); 1632 } 1633 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1634 // an instruction 1635 std::vector<SpvId> arguments; 1636 for (size_t i = 0; i < c.fArguments.size(); i++) { 1637 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1638 } 1639 SpvId result = this->nextId(); 1640 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1641 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out); 1642 this->writeWord(this->getType(c.fType), out); 1643 this->writeWord(result, out); 1644 for (int i = 0; i < c.fType.columns(); i++) { 1645 this->writeWord(arguments[0], out); 1646 } 1647 } else { 1648 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out); 1649 this->writeWord(this->getType(c.fType), out); 1650 this->writeWord(result, out); 1651 for (SpvId id : arguments) { 1652 this->writeWord(id, out); 1653 } 1654 } 1655 return result; 1656 } 1657 1658 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) { 1659 if (c.fType == *fContext.fFloat_Type) { 1660 return this->writeFloatConstructor(c, out); 1661 } else if (c.fType == *fContext.fInt_Type) { 1662 return this->writeIntConstructor(c, out); 1663 } else if (c.fType == *fContext.fUInt_Type) { 1664 return this->writeUIntConstructor(c, out); 1665 } 1666 switch (c.fType.kind()) { 1667 case Type::kVector_Kind: 1668 return this->writeVectorConstructor(c, out); 1669 case Type::kMatrix_Kind: 1670 return this->writeMatrixConstructor(c, out); 1671 default: 1672 ABORT("unsupported constructor: %s", c.description().c_str()); 1673 } 1674 } 1675 1676 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) { 1677 if (modifiers.fFlags & Modifiers::kIn_Flag) { 1678 ASSERT(!modifiers.fLayout.fPushConstant); 1679 return SpvStorageClassInput; 1680 } else if (modifiers.fFlags & Modifiers::kOut_Flag) { 1681 ASSERT(!modifiers.fLayout.fPushConstant); 1682 return SpvStorageClassOutput; 1683 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) { 1684 if (modifiers.fLayout.fPushConstant) { 1685 return SpvStorageClassPushConstant; 1686 } 1687 return SpvStorageClassUniform; 1688 } else { 1689 return SpvStorageClassFunction; 1690 } 1691 } 1692 1693 SpvStorageClass_ get_storage_class(const Expression& expr) { 1694 switch (expr.fKind) { 1695 case Expression::kVariableReference_Kind: { 1696 const Variable& var = ((VariableReference&) expr).fVariable; 1697 if (var.fStorage != Variable::kGlobal_Storage) { 1698 return SpvStorageClassFunction; 1699 } 1700 return get_storage_class(var.fModifiers); 1701 } 1702 case Expression::kFieldAccess_Kind: 1703 return get_storage_class(*((FieldAccess&) expr).fBase); 1704 case Expression::kIndex_Kind: 1705 return get_storage_class(*((IndexExpression&) expr).fBase); 1706 default: 1707 return SpvStorageClassFunction; 1708 } 1709 } 1710 1711 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) { 1712 std::vector<SpvId> chain; 1713 switch (expr.fKind) { 1714 case Expression::kIndex_Kind: { 1715 IndexExpression& indexExpr = (IndexExpression&) expr; 1716 chain = this->getAccessChain(*indexExpr.fBase, out); 1717 chain.push_back(this->writeExpression(*indexExpr.fIndex, out)); 1718 break; 1719 } 1720 case Expression::kFieldAccess_Kind: { 1721 FieldAccess& fieldExpr = (FieldAccess&) expr; 1722 chain = this->getAccessChain(*fieldExpr.fBase, out); 1723 IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex); 1724 chain.push_back(this->writeIntLiteral(index)); 1725 break; 1726 } 1727 default: 1728 chain.push_back(this->getLValue(expr, out)->getPointer()); 1729 } 1730 return chain; 1731 } 1732 1733 class PointerLValue : public SPIRVCodeGenerator::LValue { 1734 public: 1735 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type) 1736 : fGen(gen) 1737 , fPointer(pointer) 1738 , fType(type) {} 1739 1740 virtual SpvId getPointer() override { 1741 return fPointer; 1742 } 1743 1744 virtual SpvId load(OutputStream& out) override { 1745 SpvId result = fGen.nextId(); 1746 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out); 1747 return result; 1748 } 1749 1750 virtual void store(SpvId value, OutputStream& out) override { 1751 fGen.writeInstruction(SpvOpStore, fPointer, value, out); 1752 } 1753 1754 private: 1755 SPIRVCodeGenerator& fGen; 1756 const SpvId fPointer; 1757 const SpvId fType; 1758 }; 1759 1760 class SwizzleLValue : public SPIRVCodeGenerator::LValue { 1761 public: 1762 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components, 1763 const Type& baseType, const Type& swizzleType) 1764 : fGen(gen) 1765 , fVecPointer(vecPointer) 1766 , fComponents(components) 1767 , fBaseType(baseType) 1768 , fSwizzleType(swizzleType) {} 1769 1770 virtual SpvId getPointer() override { 1771 return 0; 1772 } 1773 1774 virtual SpvId load(OutputStream& out) override { 1775 SpvId base = fGen.nextId(); 1776 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1777 SpvId result = fGen.nextId(); 1778 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out); 1779 fGen.writeWord(fGen.getType(fSwizzleType), out); 1780 fGen.writeWord(result, out); 1781 fGen.writeWord(base, out); 1782 fGen.writeWord(base, out); 1783 for (int component : fComponents) { 1784 fGen.writeWord(component, out); 1785 } 1786 return result; 1787 } 1788 1789 virtual void store(SpvId value, OutputStream& out) override { 1790 // use OpVectorShuffle to mix and match the vector components. We effectively create 1791 // a virtual vector out of the concatenation of the left and right vectors, and then 1792 // select components from this virtual vector to make the result vector. For 1793 // instance, given: 1794 // vec3 L = ...; 1795 // vec3 R = ...; 1796 // L.xz = R.xy; 1797 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want 1798 // our result vector to look like (R.x, L.y, R.y), so we need to select indices 1799 // (3, 1, 4). 1800 SpvId base = fGen.nextId(); 1801 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1802 SpvId shuffle = fGen.nextId(); 1803 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out); 1804 fGen.writeWord(fGen.getType(fBaseType), out); 1805 fGen.writeWord(shuffle, out); 1806 fGen.writeWord(base, out); 1807 fGen.writeWord(value, out); 1808 for (int i = 0; i < fBaseType.columns(); i++) { 1809 // current offset into the virtual vector, defaults to pulling the unmodified 1810 // value from the left side 1811 int offset = i; 1812 // check to see if we are writing this component 1813 for (size_t j = 0; j < fComponents.size(); j++) { 1814 if (fComponents[j] == i) { 1815 // we're writing to this component, so adjust the offset to pull from 1816 // the correct component of the right side instead of preserving the 1817 // value from the left 1818 offset = (int) (j + fBaseType.columns()); 1819 break; 1820 } 1821 } 1822 fGen.writeWord(offset, out); 1823 } 1824 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out); 1825 } 1826 1827 private: 1828 SPIRVCodeGenerator& fGen; 1829 const SpvId fVecPointer; 1830 const std::vector<int>& fComponents; 1831 const Type& fBaseType; 1832 const Type& fSwizzleType; 1833 }; 1834 1835 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr, 1836 OutputStream& out) { 1837 switch (expr.fKind) { 1838 case Expression::kVariableReference_Kind: { 1839 const Variable& var = ((VariableReference&) expr).fVariable; 1840 auto entry = fVariableMap.find(&var); 1841 ASSERT(entry != fVariableMap.end()); 1842 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1843 *this, 1844 entry->second, 1845 this->getType(expr.fType))); 1846 } 1847 case Expression::kIndex_Kind: // fall through 1848 case Expression::kFieldAccess_Kind: { 1849 std::vector<SpvId> chain = this->getAccessChain(expr, out); 1850 SpvId member = this->nextId(); 1851 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out); 1852 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out); 1853 this->writeWord(member, out); 1854 for (SpvId idx : chain) { 1855 this->writeWord(idx, out); 1856 } 1857 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1858 *this, 1859 member, 1860 this->getType(expr.fType))); 1861 } 1862 1863 case Expression::kSwizzle_Kind: { 1864 Swizzle& swizzle = (Swizzle&) expr; 1865 size_t count = swizzle.fComponents.size(); 1866 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer(); 1867 ASSERT(base); 1868 if (count == 1) { 1869 IntLiteral index(fContext, Position(), swizzle.fComponents[0]); 1870 SpvId member = this->nextId(); 1871 this->writeInstruction(SpvOpAccessChain, 1872 this->getPointerType(swizzle.fType, 1873 get_storage_class(*swizzle.fBase)), 1874 member, 1875 base, 1876 this->writeIntLiteral(index), 1877 out); 1878 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1879 *this, 1880 member, 1881 this->getType(expr.fType))); 1882 } else { 1883 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue( 1884 *this, 1885 base, 1886 swizzle.fComponents, 1887 swizzle.fBase->fType, 1888 expr.fType)); 1889 } 1890 } 1891 1892 default: 1893 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due 1894 // to the need to store values in temporary variables during function calls (see 1895 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been 1896 // caught by IRGenerator 1897 SpvId result = this->nextId(); 1898 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction); 1899 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, 1900 fVariableBuffer); 1901 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out); 1902 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1903 *this, 1904 result, 1905 this->getType(expr.fType))); 1906 } 1907 } 1908 1909 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) { 1910 SpvId result = this->nextId(); 1911 auto entry = fVariableMap.find(&ref.fVariable); 1912 ASSERT(entry != fVariableMap.end()); 1913 SpvId var = entry->second; 1914 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out); 1915 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN && 1916 fProgram.fSettings.fFlipY) { 1917 // need to remap to a top-left coordinate system 1918 if (fRTHeightStructId == (SpvId) -1) { 1919 // height variable hasn't been written yet 1920 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors)); 1921 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 1922 std::vector<Type::Field> fields; 1923 fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME), 1924 fContext.fFloat_Type.get()); 1925 String name("sksl_synthetic_uniforms"); 1926 Type intfStruct(Position(), name, fields); 1927 Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false, 1928 Layout::Format::kUnspecified, false, Layout::kUnspecified_Primitive, -1, 1929 -1, "", Layout::kNo_Key); 1930 Variable* intfVar = new Variable(Position(), 1931 Modifiers(layout, Modifiers::kUniform_Flag), 1932 name, 1933 intfStruct, 1934 Variable::kGlobal_Storage); 1935 fSynthetics.takeOwnership(intfVar); 1936 InterfaceBlock intf(Position(), intfVar, name, String(""), 1937 std::vector<std::unique_ptr<Expression>>(), st); 1938 fRTHeightStructId = this->writeInterfaceBlock(intf); 1939 fRTHeightFieldIndex = 0; 1940 } 1941 ASSERT(fRTHeightFieldIndex != (SpvId) -1); 1942 // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0) 1943 SpvId xId = this->nextId(); 1944 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId, 1945 result, 0, out); 1946 IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex); 1947 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex); 1948 SpvId heightPtr = this->nextId(); 1949 this->writeOpCode(SpvOpAccessChain, 5, out); 1950 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out); 1951 this->writeWord(heightPtr, out); 1952 this->writeWord(fRTHeightStructId, out); 1953 this->writeWord(fieldIndexId, out); 1954 SpvId heightRead = this->nextId(); 1955 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead, 1956 heightPtr, out); 1957 SpvId rawYId = this->nextId(); 1958 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId, 1959 result, 1, out); 1960 SpvId flippedYId = this->nextId(); 1961 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId, 1962 heightRead, rawYId, out); 1963 FloatLiteral zero(fContext, Position(), 0.0); 1964 SpvId zeroId = writeFloatLiteral(zero); 1965 FloatLiteral one(fContext, Position(), 1.0); 1966 SpvId oneId = writeFloatLiteral(one); 1967 SpvId flipped = this->nextId(); 1968 this->writeOpCode(SpvOpCompositeConstruct, 7, out); 1969 this->writeWord(this->getType(*fContext.fVec4_Type), out); 1970 this->writeWord(flipped, out); 1971 this->writeWord(xId, out); 1972 this->writeWord(flippedYId, out); 1973 this->writeWord(zeroId, out); 1974 this->writeWord(oneId, out); 1975 return flipped; 1976 } 1977 return result; 1978 } 1979 1980 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) { 1981 return getLValue(expr, out)->load(out); 1982 } 1983 1984 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) { 1985 return getLValue(f, out)->load(out); 1986 } 1987 1988 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) { 1989 SpvId base = this->writeExpression(*swizzle.fBase, out); 1990 SpvId result = this->nextId(); 1991 size_t count = swizzle.fComponents.size(); 1992 if (count == 1) { 1993 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, 1994 swizzle.fComponents[0], out); 1995 } else { 1996 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); 1997 this->writeWord(this->getType(swizzle.fType), out); 1998 this->writeWord(result, out); 1999 this->writeWord(base, out); 2000 this->writeWord(base, out); 2001 for (int component : swizzle.fComponents) { 2002 this->writeWord(component, out); 2003 } 2004 } 2005 return result; 2006 } 2007 2008 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, 2009 const Type& operandType, SpvId lhs, 2010 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, 2011 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) { 2012 SpvId result = this->nextId(); 2013 if (is_float(fContext, operandType)) { 2014 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out); 2015 } else if (is_signed(fContext, operandType)) { 2016 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out); 2017 } else if (is_unsigned(fContext, operandType)) { 2018 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out); 2019 } else if (operandType == *fContext.fBool_Type) { 2020 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out); 2021 } else { 2022 ABORT("invalid operandType: %s", operandType.description().c_str()); 2023 } 2024 return result; 2025 } 2026 2027 bool is_assignment(Token::Kind op) { 2028 switch (op) { 2029 case Token::EQ: // fall through 2030 case Token::PLUSEQ: // fall through 2031 case Token::MINUSEQ: // fall through 2032 case Token::STAREQ: // fall through 2033 case Token::SLASHEQ: // fall through 2034 case Token::PERCENTEQ: // fall through 2035 case Token::SHLEQ: // fall through 2036 case Token::SHREQ: // fall through 2037 case Token::BITWISEOREQ: // fall through 2038 case Token::BITWISEXOREQ: // fall through 2039 case Token::BITWISEANDEQ: // fall through 2040 case Token::LOGICALOREQ: // fall through 2041 case Token::LOGICALXOREQ: // fall through 2042 case Token::LOGICALANDEQ: 2043 return true; 2044 default: 2045 return false; 2046 } 2047 } 2048 2049 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) { 2050 if (operandType.kind() == Type::kVector_Kind) { 2051 SpvId result = this->nextId(); 2052 this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out); 2053 return result; 2054 } 2055 return id; 2056 } 2057 2058 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, 2059 SpvOp_ floatOperator, SpvOp_ intOperator, 2060 OutputStream& out) { 2061 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator; 2062 ASSERT(operandType.kind() == Type::kMatrix_Kind); 2063 SpvId rowType = this->getType(operandType.componentType().toCompound(fContext, 2064 operandType.columns(), 2065 1)); 2066 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext, 2067 operandType.columns(), 2068 1)); 2069 SpvId boolType = this->getType(*fContext.fBool_Type); 2070 SpvId result = 0; 2071 for (int i = 0; i < operandType.rows(); i++) { 2072 SpvId rowL = this->nextId(); 2073 this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out); 2074 SpvId rowR = this->nextId(); 2075 this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out); 2076 SpvId compare = this->nextId(); 2077 this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out); 2078 SpvId all = this->nextId(); 2079 this->writeInstruction(SpvOpAll, boolType, all, compare, out); 2080 if (result != 0) { 2081 SpvId next = this->nextId(); 2082 this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out); 2083 result = next; 2084 } 2085 else { 2086 result = all; 2087 } 2088 } 2089 return result; 2090 } 2091 2092 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) { 2093 // handle cases where we don't necessarily evaluate both LHS and RHS 2094 switch (b.fOperator) { 2095 case Token::EQ: { 2096 SpvId rhs = this->writeExpression(*b.fRight, out); 2097 this->getLValue(*b.fLeft, out)->store(rhs, out); 2098 return rhs; 2099 } 2100 case Token::LOGICALAND: 2101 return this->writeLogicalAnd(b, out); 2102 case Token::LOGICALOR: 2103 return this->writeLogicalOr(b, out); 2104 default: 2105 break; 2106 } 2107 2108 // "normal" operators 2109 const Type& resultType = b.fType; 2110 std::unique_ptr<LValue> lvalue; 2111 SpvId lhs; 2112 if (is_assignment(b.fOperator)) { 2113 lvalue = this->getLValue(*b.fLeft, out); 2114 lhs = lvalue->load(out); 2115 } else { 2116 lvalue = nullptr; 2117 lhs = this->writeExpression(*b.fLeft, out); 2118 } 2119 SpvId rhs = this->writeExpression(*b.fRight, out); 2120 if (b.fOperator == Token::COMMA) { 2121 return rhs; 2122 } 2123 // component type we are operating on: float, int, uint 2124 const Type* operandType; 2125 // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling 2126 // in SPIR-V 2127 if (b.fLeft->fType != b.fRight->fType) { 2128 if (b.fLeft->fType.kind() == Type::kVector_Kind && 2129 b.fRight->fType.isNumber()) { 2130 // promote number to vector 2131 SpvId vec = this->nextId(); 2132 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 2133 this->writeWord(this->getType(resultType), out); 2134 this->writeWord(vec, out); 2135 for (int i = 0; i < resultType.columns(); i++) { 2136 this->writeWord(rhs, out); 2137 } 2138 rhs = vec; 2139 operandType = &b.fRight->fType; 2140 } else if (b.fRight->fType.kind() == Type::kVector_Kind && 2141 b.fLeft->fType.isNumber()) { 2142 // promote number to vector 2143 SpvId vec = this->nextId(); 2144 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 2145 this->writeWord(this->getType(resultType), out); 2146 this->writeWord(vec, out); 2147 for (int i = 0; i < resultType.columns(); i++) { 2148 this->writeWord(lhs, out); 2149 } 2150 lhs = vec; 2151 ASSERT(!lvalue); 2152 operandType = &b.fLeft->fType; 2153 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) { 2154 SpvOp_ op; 2155 if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 2156 op = SpvOpMatrixTimesMatrix; 2157 } else if (b.fRight->fType.kind() == Type::kVector_Kind) { 2158 op = SpvOpMatrixTimesVector; 2159 } else { 2160 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind); 2161 op = SpvOpMatrixTimesScalar; 2162 } 2163 SpvId result = this->nextId(); 2164 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out); 2165 if (b.fOperator == Token::STAREQ) { 2166 lvalue->store(result, out); 2167 } else { 2168 ASSERT(b.fOperator == Token::STAR); 2169 } 2170 return result; 2171 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 2172 SpvId result = this->nextId(); 2173 if (b.fLeft->fType.kind() == Type::kVector_Kind) { 2174 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, 2175 lhs, rhs, out); 2176 } else { 2177 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind); 2178 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, 2179 lhs, out); 2180 } 2181 if (b.fOperator == Token::STAREQ) { 2182 lvalue->store(result, out); 2183 } else { 2184 ASSERT(b.fOperator == Token::STAR); 2185 } 2186 return result; 2187 } else { 2188 ABORT("unsupported binary expression: %s", b.description().c_str()); 2189 } 2190 } else { 2191 operandType = &b.fLeft->fType; 2192 ASSERT(*operandType == b.fRight->fType); 2193 } 2194 switch (b.fOperator) { 2195 case Token::EQEQ: { 2196 if (operandType->kind() == Type::kMatrix_Kind) { 2197 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual, 2198 SpvOpIEqual, out); 2199 } 2200 ASSERT(resultType == *fContext.fBool_Type); 2201 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2202 SpvOpFOrdEqual, SpvOpIEqual, 2203 SpvOpIEqual, SpvOpLogicalEqual, out), 2204 *operandType, out); 2205 } 2206 case Token::NEQ: 2207 if (operandType->kind() == Type::kMatrix_Kind) { 2208 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual, 2209 SpvOpINotEqual, out); 2210 } 2211 ASSERT(resultType == *fContext.fBool_Type); 2212 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2213 SpvOpFOrdNotEqual, SpvOpINotEqual, 2214 SpvOpINotEqual, SpvOpLogicalNotEqual, 2215 out), 2216 *operandType, out); 2217 case Token::GT: 2218 ASSERT(resultType == *fContext.fBool_Type); 2219 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2220 SpvOpFOrdGreaterThan, SpvOpSGreaterThan, 2221 SpvOpUGreaterThan, SpvOpUndef, out); 2222 case Token::LT: 2223 ASSERT(resultType == *fContext.fBool_Type); 2224 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, 2225 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out); 2226 case Token::GTEQ: 2227 ASSERT(resultType == *fContext.fBool_Type); 2228 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2229 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, 2230 SpvOpUGreaterThanEqual, SpvOpUndef, out); 2231 case Token::LTEQ: 2232 ASSERT(resultType == *fContext.fBool_Type); 2233 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2234 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, 2235 SpvOpULessThanEqual, SpvOpUndef, out); 2236 case Token::PLUS: 2237 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 2238 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2239 case Token::MINUS: 2240 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 2241 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2242 case Token::STAR: 2243 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 2244 b.fRight->fType.kind() == Type::kMatrix_Kind) { 2245 // matrix multiply 2246 SpvId result = this->nextId(); 2247 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 2248 lhs, rhs, out); 2249 return result; 2250 } 2251 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 2252 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 2253 case Token::SLASH: 2254 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 2255 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 2256 case Token::PERCENT: 2257 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2258 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2259 case Token::SHL: 2260 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2261 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, 2262 SpvOpUndef, out); 2263 case Token::SHR: 2264 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2265 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical, 2266 SpvOpUndef, out); 2267 case Token::BITWISEAND: 2268 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2269 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out); 2270 case Token::BITWISEOR: 2271 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2272 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out); 2273 case Token::BITWISEXOR: 2274 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2275 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out); 2276 case Token::PLUSEQ: { 2277 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 2278 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2279 ASSERT(lvalue); 2280 lvalue->store(result, out); 2281 return result; 2282 } 2283 case Token::MINUSEQ: { 2284 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 2285 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2286 ASSERT(lvalue); 2287 lvalue->store(result, out); 2288 return result; 2289 } 2290 case Token::STAREQ: { 2291 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 2292 b.fRight->fType.kind() == Type::kMatrix_Kind) { 2293 // matrix multiply 2294 SpvId result = this->nextId(); 2295 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 2296 lhs, rhs, out); 2297 ASSERT(lvalue); 2298 lvalue->store(result, out); 2299 return result; 2300 } 2301 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 2302 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 2303 ASSERT(lvalue); 2304 lvalue->store(result, out); 2305 return result; 2306 } 2307 case Token::SLASHEQ: { 2308 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 2309 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 2310 ASSERT(lvalue); 2311 lvalue->store(result, out); 2312 return result; 2313 } 2314 case Token::PERCENTEQ: { 2315 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2316 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2317 ASSERT(lvalue); 2318 lvalue->store(result, out); 2319 return result; 2320 } 2321 case Token::SHLEQ: { 2322 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2323 SpvOpUndef, SpvOpShiftLeftLogical, 2324 SpvOpShiftLeftLogical, SpvOpUndef, out); 2325 ASSERT(lvalue); 2326 lvalue->store(result, out); 2327 return result; 2328 } 2329 case Token::SHREQ: { 2330 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2331 SpvOpUndef, SpvOpShiftRightArithmetic, 2332 SpvOpShiftRightLogical, SpvOpUndef, out); 2333 ASSERT(lvalue); 2334 lvalue->store(result, out); 2335 return result; 2336 } 2337 case Token::BITWISEANDEQ: { 2338 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2339 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd, 2340 SpvOpUndef, out); 2341 ASSERT(lvalue); 2342 lvalue->store(result, out); 2343 return result; 2344 } 2345 case Token::BITWISEOREQ: { 2346 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2347 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr, 2348 SpvOpUndef, out); 2349 ASSERT(lvalue); 2350 lvalue->store(result, out); 2351 return result; 2352 } 2353 case Token::BITWISEXOREQ: { 2354 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2355 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor, 2356 SpvOpUndef, out); 2357 ASSERT(lvalue); 2358 lvalue->store(result, out); 2359 return result; 2360 } 2361 default: 2362 ABORT("unsupported binary expression: %s", b.description().c_str()); 2363 } 2364 } 2365 2366 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) { 2367 ASSERT(a.fOperator == Token::LOGICALAND); 2368 BoolLiteral falseLiteral(fContext, Position(), false); 2369 SpvId falseConstant = this->writeBoolLiteral(falseLiteral); 2370 SpvId lhs = this->writeExpression(*a.fLeft, out); 2371 SpvId rhsLabel = this->nextId(); 2372 SpvId end = this->nextId(); 2373 SpvId lhsBlock = fCurrentBlock; 2374 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2375 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out); 2376 this->writeLabel(rhsLabel, out); 2377 SpvId rhs = this->writeExpression(*a.fRight, out); 2378 SpvId rhsBlock = fCurrentBlock; 2379 this->writeInstruction(SpvOpBranch, end, out); 2380 this->writeLabel(end, out); 2381 SpvId result = this->nextId(); 2382 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant, 2383 lhsBlock, rhs, rhsBlock, out); 2384 return result; 2385 } 2386 2387 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) { 2388 ASSERT(o.fOperator == Token::LOGICALOR); 2389 BoolLiteral trueLiteral(fContext, Position(), true); 2390 SpvId trueConstant = this->writeBoolLiteral(trueLiteral); 2391 SpvId lhs = this->writeExpression(*o.fLeft, out); 2392 SpvId rhsLabel = this->nextId(); 2393 SpvId end = this->nextId(); 2394 SpvId lhsBlock = fCurrentBlock; 2395 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2396 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out); 2397 this->writeLabel(rhsLabel, out); 2398 SpvId rhs = this->writeExpression(*o.fRight, out); 2399 SpvId rhsBlock = fCurrentBlock; 2400 this->writeInstruction(SpvOpBranch, end, out); 2401 this->writeLabel(end, out); 2402 SpvId result = this->nextId(); 2403 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant, 2404 lhsBlock, rhs, rhsBlock, out); 2405 return result; 2406 } 2407 2408 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { 2409 SpvId test = this->writeExpression(*t.fTest, out); 2410 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) { 2411 // both true and false are constants, can just use OpSelect 2412 SpvId result = this->nextId(); 2413 SpvId trueId = this->writeExpression(*t.fIfTrue, out); 2414 SpvId falseId = this->writeExpression(*t.fIfFalse, out); 2415 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, 2416 out); 2417 return result; 2418 } 2419 // was originally using OpPhi to choose the result, but for some reason that is crashing on 2420 // Adreno. Switched to storing the result in a temp variable as glslang does. 2421 SpvId var = this->nextId(); 2422 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), 2423 var, SpvStorageClassFunction, fVariableBuffer); 2424 SpvId trueLabel = this->nextId(); 2425 SpvId falseLabel = this->nextId(); 2426 SpvId end = this->nextId(); 2427 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2428 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); 2429 this->writeLabel(trueLabel, out); 2430 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out); 2431 this->writeInstruction(SpvOpBranch, end, out); 2432 this->writeLabel(falseLabel, out); 2433 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out); 2434 this->writeInstruction(SpvOpBranch, end, out); 2435 this->writeLabel(end, out); 2436 SpvId result = this->nextId(); 2437 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out); 2438 return result; 2439 } 2440 2441 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) { 2442 if (type == *context.fInt_Type) { 2443 return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1)); 2444 } 2445 else if (type == *context.fFloat_Type) { 2446 return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0)); 2447 } else { 2448 ABORT("math is unsupported on type '%s'", type.name().c_str()); 2449 } 2450 } 2451 2452 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) { 2453 if (p.fOperator == Token::MINUS) { 2454 SpvId result = this->nextId(); 2455 SpvId typeId = this->getType(p.fType); 2456 SpvId expr = this->writeExpression(*p.fOperand, out); 2457 if (is_float(fContext, p.fType)) { 2458 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); 2459 } else if (is_signed(fContext, p.fType)) { 2460 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); 2461 } else { 2462 ABORT("unsupported prefix expression %s", p.description().c_str()); 2463 }; 2464 return result; 2465 } 2466 switch (p.fOperator) { 2467 case Token::PLUS: 2468 return this->writeExpression(*p.fOperand, out); 2469 case Token::PLUSPLUS: { 2470 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2471 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2472 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2473 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, 2474 out); 2475 lv->store(result, out); 2476 return result; 2477 } 2478 case Token::MINUSMINUS: { 2479 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2480 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2481 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2482 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, 2483 out); 2484 lv->store(result, out); 2485 return result; 2486 } 2487 case Token::LOGICALNOT: { 2488 ASSERT(p.fOperand->fType == *fContext.fBool_Type); 2489 SpvId result = this->nextId(); 2490 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result, 2491 this->writeExpression(*p.fOperand, out), out); 2492 return result; 2493 } 2494 case Token::BITWISENOT: { 2495 SpvId result = this->nextId(); 2496 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result, 2497 this->writeExpression(*p.fOperand, out), out); 2498 return result; 2499 } 2500 default: 2501 ABORT("unsupported prefix expression: %s", p.description().c_str()); 2502 } 2503 } 2504 2505 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) { 2506 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2507 SpvId result = lv->load(out); 2508 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2509 switch (p.fOperator) { 2510 case Token::PLUSPLUS: { 2511 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, 2512 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2513 lv->store(temp, out); 2514 return result; 2515 } 2516 case Token::MINUSMINUS: { 2517 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, 2518 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2519 lv->store(temp, out); 2520 return result; 2521 } 2522 default: 2523 ABORT("unsupported postfix expression %s", p.description().c_str()); 2524 } 2525 } 2526 2527 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { 2528 if (b.fValue) { 2529 if (fBoolTrue == 0) { 2530 fBoolTrue = this->nextId(); 2531 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, 2532 fConstantBuffer); 2533 } 2534 return fBoolTrue; 2535 } else { 2536 if (fBoolFalse == 0) { 2537 fBoolFalse = this->nextId(); 2538 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, 2539 fConstantBuffer); 2540 } 2541 return fBoolFalse; 2542 } 2543 } 2544 2545 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) { 2546 if (i.fType == *fContext.fInt_Type) { 2547 auto entry = fIntConstants.find(i.fValue); 2548 if (entry == fIntConstants.end()) { 2549 SpvId result = this->nextId(); 2550 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2551 fConstantBuffer); 2552 fIntConstants[i.fValue] = result; 2553 return result; 2554 } 2555 return entry->second; 2556 } else { 2557 ASSERT(i.fType == *fContext.fUInt_Type); 2558 auto entry = fUIntConstants.find(i.fValue); 2559 if (entry == fUIntConstants.end()) { 2560 SpvId result = this->nextId(); 2561 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2562 fConstantBuffer); 2563 fUIntConstants[i.fValue] = result; 2564 return result; 2565 } 2566 return entry->second; 2567 } 2568 } 2569 2570 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { 2571 if (f.fType == *fContext.fFloat_Type) { 2572 float value = (float) f.fValue; 2573 auto entry = fFloatConstants.find(value); 2574 if (entry == fFloatConstants.end()) { 2575 SpvId result = this->nextId(); 2576 uint32_t bits; 2577 ASSERT(sizeof(bits) == sizeof(value)); 2578 memcpy(&bits, &value, sizeof(bits)); 2579 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, 2580 fConstantBuffer); 2581 fFloatConstants[value] = result; 2582 return result; 2583 } 2584 return entry->second; 2585 } else { 2586 ASSERT(f.fType == *fContext.fDouble_Type); 2587 auto entry = fDoubleConstants.find(f.fValue); 2588 if (entry == fDoubleConstants.end()) { 2589 SpvId result = this->nextId(); 2590 uint64_t bits; 2591 ASSERT(sizeof(bits) == sizeof(f.fValue)); 2592 memcpy(&bits, &f.fValue, sizeof(bits)); 2593 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, 2594 bits & 0xffffffff, bits >> 32, fConstantBuffer); 2595 fDoubleConstants[f.fValue] = result; 2596 return result; 2597 } 2598 return entry->second; 2599 } 2600 } 2601 2602 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) { 2603 SpvId result = fFunctionMap[&f]; 2604 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, 2605 SpvFunctionControlMaskNone, this->getFunctionType(f), out); 2606 this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer); 2607 for (size_t i = 0; i < f.fParameters.size(); i++) { 2608 SpvId id = this->nextId(); 2609 fVariableMap[f.fParameters[i]] = id; 2610 SpvId type; 2611 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction); 2612 this->writeInstruction(SpvOpFunctionParameter, type, id, out); 2613 } 2614 return result; 2615 } 2616 2617 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) { 2618 fVariableBuffer.reset(); 2619 SpvId result = this->writeFunctionStart(f.fDeclaration, out); 2620 this->writeLabel(this->nextId(), out); 2621 if (f.fDeclaration.fName == "main") { 2622 write_stringstream(fGlobalInitializersBuffer, out); 2623 } 2624 StringStream bodyBuffer; 2625 this->writeBlock((Block&) *f.fBody, bodyBuffer); 2626 write_stringstream(fVariableBuffer, out); 2627 write_stringstream(bodyBuffer, out); 2628 if (fCurrentBlock) { 2629 this->writeInstruction(SpvOpReturn, out); 2630 } 2631 this->writeInstruction(SpvOpFunctionEnd, out); 2632 return result; 2633 } 2634 2635 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) { 2636 if (layout.fLocation >= 0) { 2637 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation, 2638 fDecorationBuffer); 2639 } 2640 if (layout.fBinding >= 0) { 2641 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding, 2642 fDecorationBuffer); 2643 } 2644 if (layout.fIndex >= 0) { 2645 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex, 2646 fDecorationBuffer); 2647 } 2648 if (layout.fSet >= 0) { 2649 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet, 2650 fDecorationBuffer); 2651 } 2652 if (layout.fInputAttachmentIndex >= 0) { 2653 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex, 2654 layout.fInputAttachmentIndex, fDecorationBuffer); 2655 } 2656 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) { 2657 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin, 2658 fDecorationBuffer); 2659 } 2660 } 2661 2662 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) { 2663 if (layout.fLocation >= 0) { 2664 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation, 2665 layout.fLocation, fDecorationBuffer); 2666 } 2667 if (layout.fBinding >= 0) { 2668 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding, 2669 layout.fBinding, fDecorationBuffer); 2670 } 2671 if (layout.fIndex >= 0) { 2672 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex, 2673 layout.fIndex, fDecorationBuffer); 2674 } 2675 if (layout.fSet >= 0) { 2676 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet, 2677 layout.fSet, fDecorationBuffer); 2678 } 2679 if (layout.fInputAttachmentIndex >= 0) { 2680 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex, 2681 layout.fInputAttachmentIndex, fDecorationBuffer); 2682 } 2683 if (layout.fBuiltin >= 0) { 2684 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn, 2685 layout.fBuiltin, fDecorationBuffer); 2686 } 2687 } 2688 2689 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) { 2690 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag)); 2691 MemoryLayout layout = (intf.fVariable.fModifiers.fLayout.fPushConstant || isBuffer) ? 2692 MemoryLayout(MemoryLayout::k430_Standard) : 2693 fDefaultLayout; 2694 SpvId result = this->nextId(); 2695 const Type* type = &intf.fVariable.fType; 2696 if (fProgram.fInputs.fRTHeight) { 2697 ASSERT(fRTHeightStructId == (SpvId) -1); 2698 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 2699 std::vector<Type::Field> fields = type->fields(); 2700 fRTHeightStructId = result; 2701 fRTHeightFieldIndex = fields.size(); 2702 fields.emplace_back(Modifiers(), String(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get()); 2703 type = new Type(type->fPosition, type->name(), fields); 2704 } 2705 SpvId typeId = this->getType(*type, layout); 2706 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) { 2707 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer); 2708 } else { 2709 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer); 2710 } 2711 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers); 2712 SpvId ptrType = this->nextId(); 2713 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer); 2714 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); 2715 this->writeLayout(intf.fVariable.fModifiers.fLayout, result); 2716 fVariableMap[&intf.fVariable] = result; 2717 if (fProgram.fInputs.fRTHeight) { 2718 delete type; 2719 } 2720 return result; 2721 } 2722 2723 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) { 2724 if ((modifiers.fFlags & Modifiers::kLowp_Flag) | 2725 (modifiers.fFlags & Modifiers::kMediump_Flag)) { 2726 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer); 2727 } 2728 } 2729 2730 #define BUILTIN_IGNORE 9999 2731 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl, 2732 OutputStream& out) { 2733 for (size_t i = 0; i < decl.fVars.size(); i++) { 2734 if (decl.fVars[i]->fKind == Statement::kNop_Kind) { 2735 continue; 2736 } 2737 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i]; 2738 const Variable* var = varDecl.fVar; 2739 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2740 // in the OpenGL backend. 2741 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2742 Modifiers::kWriteOnly_Flag | 2743 Modifiers::kCoherent_Flag | 2744 Modifiers::kVolatile_Flag | 2745 Modifiers::kRestrict_Flag))); 2746 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) { 2747 continue; 2748 } 2749 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN && 2750 kind != Program::kFragment_Kind) { 2751 continue; 2752 } 2753 if (!var->fReadCount && !var->fWriteCount && 2754 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag | 2755 Modifiers::kOut_Flag | 2756 Modifiers::kUniform_Flag | 2757 Modifiers::kBuffer_Flag))) { 2758 // variable is dead and not an input / output var (the Vulkan debug layers complain if 2759 // we elide an interface var, even if it's dead) 2760 continue; 2761 } 2762 SpvStorageClass_ storageClass; 2763 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) { 2764 storageClass = SpvStorageClassInput; 2765 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) { 2766 storageClass = SpvStorageClassOutput; 2767 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) { 2768 if (var->fType.kind() == Type::kSampler_Kind) { 2769 storageClass = SpvStorageClassUniformConstant; 2770 } else { 2771 storageClass = SpvStorageClassUniform; 2772 } 2773 } else { 2774 storageClass = SpvStorageClassPrivate; 2775 } 2776 SpvId id = this->nextId(); 2777 fVariableMap[var] = id; 2778 SpvId type = this->getPointerType(var->fType, storageClass); 2779 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer); 2780 this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer); 2781 this->writePrecisionModifier(var->fModifiers, id); 2782 if (var->fType.kind() == Type::kMatrix_Kind) { 2783 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor, 2784 fDecorationBuffer); 2785 this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride, 2786 (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer); 2787 } 2788 if (varDecl.fValue) { 2789 ASSERT(!fCurrentBlock); 2790 fCurrentBlock = -1; 2791 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer); 2792 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer); 2793 fCurrentBlock = 0; 2794 } 2795 this->writeLayout(var->fModifiers.fLayout, id); 2796 } 2797 } 2798 2799 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) { 2800 for (const auto& stmt : decl.fVars) { 2801 ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind); 2802 VarDeclaration& varDecl = (VarDeclaration&) *stmt; 2803 const Variable* var = varDecl.fVar; 2804 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2805 // in the OpenGL backend. 2806 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2807 Modifiers::kWriteOnly_Flag | 2808 Modifiers::kCoherent_Flag | 2809 Modifiers::kVolatile_Flag | 2810 Modifiers::kRestrict_Flag))); 2811 SpvId id = this->nextId(); 2812 fVariableMap[var] = id; 2813 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction); 2814 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer); 2815 this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer); 2816 if (varDecl.fValue) { 2817 SpvId value = this->writeExpression(*varDecl.fValue, out); 2818 this->writeInstruction(SpvOpStore, id, value, out); 2819 } 2820 } 2821 } 2822 2823 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) { 2824 switch (s.fKind) { 2825 case Statement::kNop_Kind: 2826 break; 2827 case Statement::kBlock_Kind: 2828 this->writeBlock((Block&) s, out); 2829 break; 2830 case Statement::kExpression_Kind: 2831 this->writeExpression(*((ExpressionStatement&) s).fExpression, out); 2832 break; 2833 case Statement::kReturn_Kind: 2834 this->writeReturnStatement((ReturnStatement&) s, out); 2835 break; 2836 case Statement::kVarDeclarations_Kind: 2837 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out); 2838 break; 2839 case Statement::kIf_Kind: 2840 this->writeIfStatement((IfStatement&) s, out); 2841 break; 2842 case Statement::kFor_Kind: 2843 this->writeForStatement((ForStatement&) s, out); 2844 break; 2845 case Statement::kWhile_Kind: 2846 this->writeWhileStatement((WhileStatement&) s, out); 2847 break; 2848 case Statement::kDo_Kind: 2849 this->writeDoStatement((DoStatement&) s, out); 2850 break; 2851 case Statement::kBreak_Kind: 2852 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out); 2853 break; 2854 case Statement::kContinue_Kind: 2855 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out); 2856 break; 2857 case Statement::kDiscard_Kind: 2858 this->writeInstruction(SpvOpKill, out); 2859 break; 2860 default: 2861 ABORT("unsupported statement: %s", s.description().c_str()); 2862 } 2863 } 2864 2865 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) { 2866 for (size_t i = 0; i < b.fStatements.size(); i++) { 2867 this->writeStatement(*b.fStatements[i], out); 2868 } 2869 } 2870 2871 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) { 2872 SpvId test = this->writeExpression(*stmt.fTest, out); 2873 SpvId ifTrue = this->nextId(); 2874 SpvId ifFalse = this->nextId(); 2875 if (stmt.fIfFalse) { 2876 SpvId end = this->nextId(); 2877 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2878 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2879 this->writeLabel(ifTrue, out); 2880 this->writeStatement(*stmt.fIfTrue, out); 2881 if (fCurrentBlock) { 2882 this->writeInstruction(SpvOpBranch, end, out); 2883 } 2884 this->writeLabel(ifFalse, out); 2885 this->writeStatement(*stmt.fIfFalse, out); 2886 if (fCurrentBlock) { 2887 this->writeInstruction(SpvOpBranch, end, out); 2888 } 2889 this->writeLabel(end, out); 2890 } else { 2891 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out); 2892 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2893 this->writeLabel(ifTrue, out); 2894 this->writeStatement(*stmt.fIfTrue, out); 2895 if (fCurrentBlock) { 2896 this->writeInstruction(SpvOpBranch, ifFalse, out); 2897 } 2898 this->writeLabel(ifFalse, out); 2899 } 2900 } 2901 2902 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) { 2903 if (f.fInitializer) { 2904 this->writeStatement(*f.fInitializer, out); 2905 } 2906 SpvId header = this->nextId(); 2907 SpvId start = this->nextId(); 2908 SpvId body = this->nextId(); 2909 SpvId next = this->nextId(); 2910 fContinueTarget.push(next); 2911 SpvId end = this->nextId(); 2912 fBreakTarget.push(end); 2913 this->writeInstruction(SpvOpBranch, header, out); 2914 this->writeLabel(header, out); 2915 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out); 2916 this->writeInstruction(SpvOpBranch, start, out); 2917 this->writeLabel(start, out); 2918 if (f.fTest) { 2919 SpvId test = this->writeExpression(*f.fTest, out); 2920 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2921 } 2922 this->writeLabel(body, out); 2923 this->writeStatement(*f.fStatement, out); 2924 if (fCurrentBlock) { 2925 this->writeInstruction(SpvOpBranch, next, out); 2926 } 2927 this->writeLabel(next, out); 2928 if (f.fNext) { 2929 this->writeExpression(*f.fNext, out); 2930 } 2931 this->writeInstruction(SpvOpBranch, header, out); 2932 this->writeLabel(end, out); 2933 fBreakTarget.pop(); 2934 fContinueTarget.pop(); 2935 } 2936 2937 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) { 2938 // We believe the while loop code below will work, but Skia doesn't actually use them and 2939 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2940 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2941 // message, simply remove the error call below to see whether our while loop support actually 2942 // works. 2943 fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, " 2944 "see SkSLSPIRVCodeGenerator.cpp for details"); 2945 2946 SpvId header = this->nextId(); 2947 SpvId start = this->nextId(); 2948 SpvId body = this->nextId(); 2949 fContinueTarget.push(start); 2950 SpvId end = this->nextId(); 2951 fBreakTarget.push(end); 2952 this->writeInstruction(SpvOpBranch, header, out); 2953 this->writeLabel(header, out); 2954 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2955 this->writeInstruction(SpvOpBranch, start, out); 2956 this->writeLabel(start, out); 2957 SpvId test = this->writeExpression(*w.fTest, out); 2958 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2959 this->writeLabel(body, out); 2960 this->writeStatement(*w.fStatement, out); 2961 if (fCurrentBlock) { 2962 this->writeInstruction(SpvOpBranch, start, out); 2963 } 2964 this->writeLabel(end, out); 2965 fBreakTarget.pop(); 2966 fContinueTarget.pop(); 2967 } 2968 2969 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) { 2970 // We believe the do loop code below will work, but Skia doesn't actually use them and 2971 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2972 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2973 // message, simply remove the error call below to see whether our do loop support actually 2974 // works. 2975 fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see " 2976 "SkSLSPIRVCodeGenerator.cpp for details"); 2977 2978 SpvId header = this->nextId(); 2979 SpvId start = this->nextId(); 2980 SpvId next = this->nextId(); 2981 fContinueTarget.push(next); 2982 SpvId end = this->nextId(); 2983 fBreakTarget.push(end); 2984 this->writeInstruction(SpvOpBranch, header, out); 2985 this->writeLabel(header, out); 2986 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2987 this->writeInstruction(SpvOpBranch, start, out); 2988 this->writeLabel(start, out); 2989 this->writeStatement(*d.fStatement, out); 2990 if (fCurrentBlock) { 2991 this->writeInstruction(SpvOpBranch, next, out); 2992 } 2993 this->writeLabel(next, out); 2994 SpvId test = this->writeExpression(*d.fTest, out); 2995 this->writeInstruction(SpvOpBranchConditional, test, start, end, out); 2996 this->writeLabel(end, out); 2997 fBreakTarget.pop(); 2998 fContinueTarget.pop(); 2999 } 3000 3001 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) { 3002 if (r.fExpression) { 3003 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out), 3004 out); 3005 } else { 3006 this->writeInstruction(SpvOpReturn, out); 3007 } 3008 } 3009 3010 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) { 3011 fGLSLExtendedInstructions = this->nextId(); 3012 StringStream body; 3013 std::set<SpvId> interfaceVars; 3014 // assign IDs to functions 3015 for (size_t i = 0; i < program.fElements.size(); i++) { 3016 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { 3017 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i]; 3018 fFunctionMap[&f.fDeclaration] = this->nextId(); 3019 } 3020 } 3021 for (size_t i = 0; i < program.fElements.size(); i++) { 3022 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) { 3023 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i]; 3024 SpvId id = this->writeInterfaceBlock(intf); 3025 if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) || 3026 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { 3027 interfaceVars.insert(id); 3028 } 3029 } 3030 } 3031 for (size_t i = 0; i < program.fElements.size(); i++) { 3032 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) { 3033 this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]), 3034 body); 3035 } 3036 } 3037 for (size_t i = 0; i < program.fElements.size(); i++) { 3038 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { 3039 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body); 3040 } 3041 } 3042 const FunctionDeclaration* main = nullptr; 3043 for (auto entry : fFunctionMap) { 3044 if (entry.first->fName == "main") { 3045 main = entry.first; 3046 } 3047 } 3048 ASSERT(main); 3049 for (auto entry : fVariableMap) { 3050 const Variable* var = entry.first; 3051 if (var->fStorage == Variable::kGlobal_Storage && 3052 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) || 3053 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) { 3054 interfaceVars.insert(entry.second); 3055 } 3056 } 3057 this->writeCapabilities(out); 3058 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out); 3059 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out); 3060 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) + 3061 (int32_t) interfaceVars.size(), out); 3062 switch (program.fKind) { 3063 case Program::kVertex_Kind: 3064 this->writeWord(SpvExecutionModelVertex, out); 3065 break; 3066 case Program::kFragment_Kind: 3067 this->writeWord(SpvExecutionModelFragment, out); 3068 break; 3069 case Program::kGeometry_Kind: 3070 this->writeWord(SpvExecutionModelGeometry, out); 3071 break; 3072 default: 3073 ABORT("cannot write this kind of program to SPIR-V\n"); 3074 } 3075 this->writeWord(fFunctionMap[main], out); 3076 this->writeString(main->fName.c_str(), out); 3077 for (int var : interfaceVars) { 3078 this->writeWord(var, out); 3079 } 3080 if (program.fKind == Program::kFragment_Kind) { 3081 this->writeInstruction(SpvOpExecutionMode, 3082 fFunctionMap[main], 3083 SpvExecutionModeOriginUpperLeft, 3084 out); 3085 } 3086 for (size_t i = 0; i < program.fElements.size(); i++) { 3087 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) { 3088 this->writeInstruction(SpvOpSourceExtension, 3089 ((Extension&) *program.fElements[i]).fName.c_str(), 3090 out); 3091 } 3092 } 3093 3094 write_stringstream(fExtraGlobalsBuffer, out); 3095 write_stringstream(fNameBuffer, out); 3096 write_stringstream(fDecorationBuffer, out); 3097 write_stringstream(fConstantBuffer, out); 3098 write_stringstream(fExternalFunctionsBuffer, out); 3099 write_stringstream(body, out); 3100 } 3101 3102 bool SPIRVCodeGenerator::generateCode() { 3103 ASSERT(!fErrors.errorCount()); 3104 this->writeWord(SpvMagicNumber, *fOut); 3105 this->writeWord(SpvVersion, *fOut); 3106 this->writeWord(SKSL_MAGIC, *fOut); 3107 StringStream buffer; 3108 this->writeInstructions(fProgram, buffer); 3109 this->writeWord(fIdCount, *fOut); 3110 this->writeWord(0, *fOut); // reserved, always zero 3111 write_stringstream(buffer, *fOut); 3112 return 0 == fErrors.errorCount(); 3113 } 3114 3115 } 3116