1 /* 2 * Copyright 2016 Google Inc. 3 * 4 * Use of this source code is governed by a BSD-style license that can be 5 * found in the LICENSE file. 6 */ 7 8 #include "SkSLSPIRVCodeGenerator.h" 9 10 #include "GLSL.std.450.h" 11 12 #include "ir/SkSLExpressionStatement.h" 13 #include "ir/SkSLExtension.h" 14 #include "ir/SkSLIndexExpression.h" 15 #include "ir/SkSLVariableReference.h" 16 #include "SkSLCompiler.h" 17 18 namespace SkSL { 19 20 static const int32_t SKSL_MAGIC = 0x0; // FIXME: we should probably register a magic number 21 22 void SPIRVCodeGenerator::setupIntrinsics() { 23 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \ 24 GLSLstd450 ## x, GLSLstd450 ## x) 25 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \ 26 GLSLstd450 ## ifFloat, \ 27 GLSLstd450 ## ifInt, \ 28 GLSLstd450 ## ifUInt, \ 29 SpvOpUndef) 30 #define ALL_SPIRV(x) std::make_tuple(kSPIRV_IntrinsicKind, SpvOp ## x, SpvOp ## x, SpvOp ## x, \ 31 SpvOp ## x) 32 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \ 33 k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \ 34 k ## x ## _SpecialIntrinsic) 35 fIntrinsicMap[String("round")] = ALL_GLSL(Round); 36 fIntrinsicMap[String("roundEven")] = ALL_GLSL(RoundEven); 37 fIntrinsicMap[String("trunc")] = ALL_GLSL(Trunc); 38 fIntrinsicMap[String("abs")] = BY_TYPE_GLSL(FAbs, SAbs, SAbs); 39 fIntrinsicMap[String("sign")] = BY_TYPE_GLSL(FSign, SSign, SSign); 40 fIntrinsicMap[String("floor")] = ALL_GLSL(Floor); 41 fIntrinsicMap[String("ceil")] = ALL_GLSL(Ceil); 42 fIntrinsicMap[String("fract")] = ALL_GLSL(Fract); 43 fIntrinsicMap[String("radians")] = ALL_GLSL(Radians); 44 fIntrinsicMap[String("degrees")] = ALL_GLSL(Degrees); 45 fIntrinsicMap[String("sin")] = ALL_GLSL(Sin); 46 fIntrinsicMap[String("cos")] = ALL_GLSL(Cos); 47 fIntrinsicMap[String("tan")] = ALL_GLSL(Tan); 48 fIntrinsicMap[String("asin")] = ALL_GLSL(Asin); 49 fIntrinsicMap[String("acos")] = ALL_GLSL(Acos); 50 fIntrinsicMap[String("atan")] = SPECIAL(Atan); 51 fIntrinsicMap[String("sinh")] = ALL_GLSL(Sinh); 52 fIntrinsicMap[String("cosh")] = ALL_GLSL(Cosh); 53 fIntrinsicMap[String("tanh")] = ALL_GLSL(Tanh); 54 fIntrinsicMap[String("asinh")] = ALL_GLSL(Asinh); 55 fIntrinsicMap[String("acosh")] = ALL_GLSL(Acosh); 56 fIntrinsicMap[String("atanh")] = ALL_GLSL(Atanh); 57 fIntrinsicMap[String("pow")] = ALL_GLSL(Pow); 58 fIntrinsicMap[String("exp")] = ALL_GLSL(Exp); 59 fIntrinsicMap[String("log")] = ALL_GLSL(Log); 60 fIntrinsicMap[String("exp2")] = ALL_GLSL(Exp2); 61 fIntrinsicMap[String("log2")] = ALL_GLSL(Log2); 62 fIntrinsicMap[String("sqrt")] = ALL_GLSL(Sqrt); 63 fIntrinsicMap[String("inverse")] = ALL_GLSL(MatrixInverse); 64 fIntrinsicMap[String("transpose")] = ALL_SPIRV(Transpose); 65 fIntrinsicMap[String("inversesqrt")] = ALL_GLSL(InverseSqrt); 66 fIntrinsicMap[String("determinant")] = ALL_GLSL(Determinant); 67 fIntrinsicMap[String("matrixInverse")] = ALL_GLSL(MatrixInverse); 68 fIntrinsicMap[String("mod")] = SPECIAL(Mod); 69 fIntrinsicMap[String("min")] = SPECIAL(Min); 70 fIntrinsicMap[String("max")] = SPECIAL(Max); 71 fIntrinsicMap[String("clamp")] = SPECIAL(Clamp); 72 fIntrinsicMap[String("dot")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot, 73 SpvOpUndef, SpvOpUndef, SpvOpUndef); 74 fIntrinsicMap[String("mix")] = SPECIAL(Mix); 75 fIntrinsicMap[String("step")] = ALL_GLSL(Step); 76 fIntrinsicMap[String("smoothstep")] = ALL_GLSL(SmoothStep); 77 fIntrinsicMap[String("fma")] = ALL_GLSL(Fma); 78 fIntrinsicMap[String("frexp")] = ALL_GLSL(Frexp); 79 fIntrinsicMap[String("ldexp")] = ALL_GLSL(Ldexp); 80 81 #define PACK(type) fIntrinsicMap[String("pack" #type)] = ALL_GLSL(Pack ## type); \ 82 fIntrinsicMap[String("unpack" #type)] = ALL_GLSL(Unpack ## type) 83 PACK(Snorm4x8); 84 PACK(Unorm4x8); 85 PACK(Snorm2x16); 86 PACK(Unorm2x16); 87 PACK(Half2x16); 88 PACK(Double2x32); 89 fIntrinsicMap[String("length")] = ALL_GLSL(Length); 90 fIntrinsicMap[String("distance")] = ALL_GLSL(Distance); 91 fIntrinsicMap[String("cross")] = ALL_GLSL(Cross); 92 fIntrinsicMap[String("normalize")] = ALL_GLSL(Normalize); 93 fIntrinsicMap[String("faceForward")] = ALL_GLSL(FaceForward); 94 fIntrinsicMap[String("reflect")] = ALL_GLSL(Reflect); 95 fIntrinsicMap[String("refract")] = ALL_GLSL(Refract); 96 fIntrinsicMap[String("findLSB")] = ALL_GLSL(FindILsb); 97 fIntrinsicMap[String("findMSB")] = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb); 98 fIntrinsicMap[String("dFdx")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx, 99 SpvOpUndef, SpvOpUndef, SpvOpUndef); 100 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 101 SpvOpUndef, SpvOpUndef, SpvOpUndef); 102 fIntrinsicMap[String("dFdy")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy, 103 SpvOpUndef, SpvOpUndef, SpvOpUndef); 104 fIntrinsicMap[String("texture")] = SPECIAL(Texture); 105 fIntrinsicMap[String("texelFetch")] = SPECIAL(TexelFetch); 106 fIntrinsicMap[String("subpassLoad")] = SPECIAL(SubpassLoad); 107 108 fIntrinsicMap[String("any")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 109 SpvOpUndef, SpvOpUndef, SpvOpAny); 110 fIntrinsicMap[String("all")] = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef, 111 SpvOpUndef, SpvOpUndef, SpvOpAll); 112 fIntrinsicMap[String("equal")] = std::make_tuple(kSPIRV_IntrinsicKind, 113 SpvOpFOrdEqual, SpvOpIEqual, 114 SpvOpIEqual, SpvOpLogicalEqual); 115 fIntrinsicMap[String("notEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 116 SpvOpFOrdNotEqual, SpvOpINotEqual, 117 SpvOpINotEqual, 118 SpvOpLogicalNotEqual); 119 fIntrinsicMap[String("lessThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 120 SpvOpFOrdLessThan, SpvOpSLessThan, 121 SpvOpULessThan, SpvOpUndef); 122 fIntrinsicMap[String("lessThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 123 SpvOpFOrdLessThanEqual, 124 SpvOpSLessThanEqual, 125 SpvOpULessThanEqual, 126 SpvOpUndef); 127 fIntrinsicMap[String("greaterThan")] = std::make_tuple(kSPIRV_IntrinsicKind, 128 SpvOpFOrdGreaterThan, 129 SpvOpSGreaterThan, 130 SpvOpUGreaterThan, 131 SpvOpUndef); 132 fIntrinsicMap[String("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind, 133 SpvOpFOrdGreaterThanEqual, 134 SpvOpSGreaterThanEqual, 135 SpvOpUGreaterThanEqual, 136 SpvOpUndef); 137 fIntrinsicMap[String("EmitVertex")] = ALL_SPIRV(EmitVertex); 138 fIntrinsicMap[String("EndPrimitive")] = ALL_SPIRV(EndPrimitive); 139 // interpolateAt* not yet supported... 140 } 141 142 void SPIRVCodeGenerator::writeWord(int32_t word, OutputStream& out) { 143 out.write((const char*) &word, sizeof(word)); 144 } 145 146 static bool is_float(const Context& context, const Type& type) { 147 if (type.kind() == Type::kVector_Kind) { 148 return is_float(context, type.componentType()); 149 } 150 return type == *context.fFloat_Type || type == *context.fHalf_Type || 151 type == *context.fDouble_Type; 152 } 153 154 static bool is_signed(const Context& context, const Type& type) { 155 if (type.kind() == Type::kVector_Kind) { 156 return is_signed(context, type.componentType()); 157 } 158 return type == *context.fInt_Type || type == *context.fShort_Type; 159 } 160 161 static bool is_unsigned(const Context& context, const Type& type) { 162 if (type.kind() == Type::kVector_Kind) { 163 return is_unsigned(context, type.componentType()); 164 } 165 return type == *context.fUInt_Type || type == *context.fUShort_Type; 166 } 167 168 static bool is_bool(const Context& context, const Type& type) { 169 if (type.kind() == Type::kVector_Kind) { 170 return is_bool(context, type.componentType()); 171 } 172 return type == *context.fBool_Type; 173 } 174 175 static bool is_out(const Variable& var) { 176 return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0; 177 } 178 179 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, OutputStream& out) { 180 ASSERT(opCode != SpvOpLoad || &out != &fConstantBuffer); 181 ASSERT(opCode != SpvOpUndef); 182 switch (opCode) { 183 case SpvOpReturn: // fall through 184 case SpvOpReturnValue: // fall through 185 case SpvOpKill: // fall through 186 case SpvOpBranch: // fall through 187 case SpvOpBranchConditional: 188 ASSERT(fCurrentBlock); 189 fCurrentBlock = 0; 190 break; 191 case SpvOpConstant: // fall through 192 case SpvOpConstantTrue: // fall through 193 case SpvOpConstantFalse: // fall through 194 case SpvOpConstantComposite: // fall through 195 case SpvOpTypeVoid: // fall through 196 case SpvOpTypeInt: // fall through 197 case SpvOpTypeFloat: // fall through 198 case SpvOpTypeBool: // fall through 199 case SpvOpTypeVector: // fall through 200 case SpvOpTypeMatrix: // fall through 201 case SpvOpTypeArray: // fall through 202 case SpvOpTypePointer: // fall through 203 case SpvOpTypeFunction: // fall through 204 case SpvOpTypeRuntimeArray: // fall through 205 case SpvOpTypeStruct: // fall through 206 case SpvOpTypeImage: // fall through 207 case SpvOpTypeSampledImage: // fall through 208 case SpvOpVariable: // fall through 209 case SpvOpFunction: // fall through 210 case SpvOpFunctionParameter: // fall through 211 case SpvOpFunctionEnd: // fall through 212 case SpvOpExecutionMode: // fall through 213 case SpvOpMemoryModel: // fall through 214 case SpvOpCapability: // fall through 215 case SpvOpExtInstImport: // fall through 216 case SpvOpEntryPoint: // fall through 217 case SpvOpSource: // fall through 218 case SpvOpSourceExtension: // fall through 219 case SpvOpName: // fall through 220 case SpvOpMemberName: // fall through 221 case SpvOpDecorate: // fall through 222 case SpvOpMemberDecorate: 223 break; 224 default: 225 ASSERT(fCurrentBlock); 226 } 227 this->writeWord((length << 16) | opCode, out); 228 } 229 230 void SPIRVCodeGenerator::writeLabel(SpvId label, OutputStream& out) { 231 fCurrentBlock = label; 232 this->writeInstruction(SpvOpLabel, label, out); 233 } 234 235 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, OutputStream& out) { 236 this->writeOpCode(opCode, 1, out); 237 } 238 239 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, OutputStream& out) { 240 this->writeOpCode(opCode, 2, out); 241 this->writeWord(word1, out); 242 } 243 244 void SPIRVCodeGenerator::writeString(const char* string, size_t length, OutputStream& out) { 245 out.write(string, length); 246 switch (length % 4) { 247 case 1: 248 out.write8(0); 249 // fall through 250 case 2: 251 out.write8(0); 252 // fall through 253 case 3: 254 out.write8(0); 255 break; 256 default: 257 this->writeWord(0, out); 258 } 259 } 260 261 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, StringFragment string, OutputStream& out) { 262 this->writeOpCode(opCode, 1 + (string.fLength + 4) / 4, out); 263 this->writeString(string.fChars, string.fLength, out); 264 } 265 266 267 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, StringFragment string, 268 OutputStream& out) { 269 this->writeOpCode(opCode, 2 + (string.fLength + 4) / 4, out); 270 this->writeWord(word1, out); 271 this->writeString(string.fChars, string.fLength, out); 272 } 273 274 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 275 StringFragment string, OutputStream& out) { 276 this->writeOpCode(opCode, 3 + (string.fLength + 4) / 4, out); 277 this->writeWord(word1, out); 278 this->writeWord(word2, out); 279 this->writeString(string.fChars, string.fLength, out); 280 } 281 282 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 283 OutputStream& out) { 284 this->writeOpCode(opCode, 3, out); 285 this->writeWord(word1, out); 286 this->writeWord(word2, out); 287 } 288 289 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 290 int32_t word3, OutputStream& out) { 291 this->writeOpCode(opCode, 4, out); 292 this->writeWord(word1, out); 293 this->writeWord(word2, out); 294 this->writeWord(word3, out); 295 } 296 297 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 298 int32_t word3, int32_t word4, OutputStream& out) { 299 this->writeOpCode(opCode, 5, out); 300 this->writeWord(word1, out); 301 this->writeWord(word2, out); 302 this->writeWord(word3, out); 303 this->writeWord(word4, out); 304 } 305 306 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 307 int32_t word3, int32_t word4, int32_t word5, 308 OutputStream& out) { 309 this->writeOpCode(opCode, 6, out); 310 this->writeWord(word1, out); 311 this->writeWord(word2, out); 312 this->writeWord(word3, out); 313 this->writeWord(word4, out); 314 this->writeWord(word5, out); 315 } 316 317 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 318 int32_t word3, int32_t word4, int32_t word5, 319 int32_t word6, OutputStream& out) { 320 this->writeOpCode(opCode, 7, out); 321 this->writeWord(word1, out); 322 this->writeWord(word2, out); 323 this->writeWord(word3, out); 324 this->writeWord(word4, out); 325 this->writeWord(word5, out); 326 this->writeWord(word6, out); 327 } 328 329 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 330 int32_t word3, int32_t word4, int32_t word5, 331 int32_t word6, int32_t word7, OutputStream& out) { 332 this->writeOpCode(opCode, 8, out); 333 this->writeWord(word1, out); 334 this->writeWord(word2, out); 335 this->writeWord(word3, out); 336 this->writeWord(word4, out); 337 this->writeWord(word5, out); 338 this->writeWord(word6, out); 339 this->writeWord(word7, out); 340 } 341 342 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2, 343 int32_t word3, int32_t word4, int32_t word5, 344 int32_t word6, int32_t word7, int32_t word8, 345 OutputStream& out) { 346 this->writeOpCode(opCode, 9, out); 347 this->writeWord(word1, out); 348 this->writeWord(word2, out); 349 this->writeWord(word3, out); 350 this->writeWord(word4, out); 351 this->writeWord(word5, out); 352 this->writeWord(word6, out); 353 this->writeWord(word7, out); 354 this->writeWord(word8, out); 355 } 356 357 void SPIRVCodeGenerator::writeCapabilities(OutputStream& out) { 358 for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) { 359 if (fCapabilities & bit) { 360 this->writeInstruction(SpvOpCapability, (SpvId) i, out); 361 } 362 } 363 if (fProgram.fKind == Program::kGeometry_Kind) { 364 this->writeInstruction(SpvOpCapability, SpvCapabilityGeometry, out); 365 } 366 } 367 368 SpvId SPIRVCodeGenerator::nextId() { 369 return fIdCount++; 370 } 371 372 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout, 373 SpvId resultId) { 374 this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer); 375 // go ahead and write all of the field types, so we don't inadvertently write them while we're 376 // in the middle of writing the struct instruction 377 std::vector<SpvId> types; 378 for (const auto& f : type.fields()) { 379 types.push_back(this->getType(*f.fType, memoryLayout)); 380 } 381 this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer); 382 this->writeWord(resultId, fConstantBuffer); 383 for (SpvId id : types) { 384 this->writeWord(id, fConstantBuffer); 385 } 386 size_t offset = 0; 387 for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) { 388 size_t size = memoryLayout.size(*type.fields()[i].fType); 389 size_t alignment = memoryLayout.alignment(*type.fields()[i].fType); 390 const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout; 391 if (fieldLayout.fOffset >= 0) { 392 if (fieldLayout.fOffset < (int) offset) { 393 fErrors.error(type.fOffset, 394 "offset of field '" + type.fields()[i].fName + "' must be at " 395 "least " + to_string((int) offset)); 396 } 397 if (fieldLayout.fOffset % alignment) { 398 fErrors.error(type.fOffset, 399 "offset of field '" + type.fields()[i].fName + "' must be a multiple" 400 " of " + to_string((int) alignment)); 401 } 402 offset = fieldLayout.fOffset; 403 } else { 404 size_t mod = offset % alignment; 405 if (mod) { 406 offset += alignment - mod; 407 } 408 } 409 this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName, fNameBuffer); 410 this->writeLayout(fieldLayout, resultId, i); 411 if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) { 412 this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset, 413 (SpvId) offset, fDecorationBuffer); 414 } 415 if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) { 416 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor, 417 fDecorationBuffer); 418 this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride, 419 (SpvId) memoryLayout.stride(*type.fields()[i].fType), 420 fDecorationBuffer); 421 } 422 offset += size; 423 Type::Kind kind = type.fields()[i].fType->kind(); 424 if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) { 425 offset += alignment - offset % alignment; 426 } 427 } 428 } 429 430 Type SPIRVCodeGenerator::getActualType(const Type& type) { 431 if (type == *fContext.fHalf_Type) { 432 return *fContext.fFloat_Type; 433 } 434 if (type == *fContext.fShort_Type) { 435 return *fContext.fInt_Type; 436 } 437 if (type == *fContext.fUShort_Type) { 438 return *fContext.fUInt_Type; 439 } 440 if (type.kind() == Type::kMatrix_Kind || type.kind() == Type::kVector_Kind) { 441 if (type.componentType() == *fContext.fHalf_Type) { 442 return fContext.fFloat_Type->toCompound(fContext, type.columns(), type.rows()); 443 } 444 if (type.componentType() == *fContext.fShort_Type) { 445 return fContext.fInt_Type->toCompound(fContext, type.columns(), type.rows()); 446 } 447 if (type.componentType() == *fContext.fUShort_Type) { 448 return fContext.fUInt_Type->toCompound(fContext, type.columns(), type.rows()); 449 } 450 } 451 return type; 452 } 453 454 SpvId SPIRVCodeGenerator::getType(const Type& type) { 455 return this->getType(type, fDefaultLayout); 456 } 457 458 SpvId SPIRVCodeGenerator::getType(const Type& rawType, const MemoryLayout& layout) { 459 Type type = this->getActualType(rawType); 460 String key = type.name() + to_string((int) layout.fStd); 461 auto entry = fTypeMap.find(key); 462 if (entry == fTypeMap.end()) { 463 SpvId result = this->nextId(); 464 switch (type.kind()) { 465 case Type::kScalar_Kind: 466 if (type == *fContext.fBool_Type) { 467 this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer); 468 } else if (type == *fContext.fInt_Type) { 469 this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer); 470 } else if (type == *fContext.fUInt_Type) { 471 this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer); 472 } else if (type == *fContext.fFloat_Type) { 473 this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer); 474 } else if (type == *fContext.fDouble_Type) { 475 this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer); 476 } else { 477 ASSERT(false); 478 } 479 break; 480 case Type::kVector_Kind: 481 this->writeInstruction(SpvOpTypeVector, result, 482 this->getType(type.componentType(), layout), 483 type.columns(), fConstantBuffer); 484 break; 485 case Type::kMatrix_Kind: 486 this->writeInstruction(SpvOpTypeMatrix, result, 487 this->getType(index_type(fContext, type), layout), 488 type.columns(), fConstantBuffer); 489 break; 490 case Type::kStruct_Kind: 491 this->writeStruct(type, layout, result); 492 break; 493 case Type::kArray_Kind: { 494 if (type.columns() > 0) { 495 IntLiteral count(fContext, -1, type.columns()); 496 this->writeInstruction(SpvOpTypeArray, result, 497 this->getType(type.componentType(), layout), 498 this->writeIntLiteral(count), fConstantBuffer); 499 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 500 (int32_t) layout.stride(type), 501 fDecorationBuffer); 502 } else { 503 this->writeInstruction(SpvOpTypeRuntimeArray, result, 504 this->getType(type.componentType(), layout), 505 fConstantBuffer); 506 this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride, 507 (int32_t) layout.stride(type), 508 fDecorationBuffer); 509 } 510 break; 511 } 512 case Type::kSampler_Kind: { 513 SpvId image = result; 514 if (SpvDimSubpassData != type.dimensions()) { 515 image = this->nextId(); 516 } 517 if (SpvDimBuffer == type.dimensions()) { 518 fCapabilities |= (((uint64_t) 1) << SpvCapabilitySampledBuffer); 519 } 520 this->writeInstruction(SpvOpTypeImage, image, 521 this->getType(*fContext.fFloat_Type, layout), 522 type.dimensions(), type.isDepth(), type.isArrayed(), 523 type.isMultisampled(), type.isSampled() ? 1 : 2, 524 SpvImageFormatUnknown, fConstantBuffer); 525 fImageTypeMap[key] = image; 526 if (SpvDimSubpassData != type.dimensions()) { 527 this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer); 528 } 529 break; 530 } 531 default: 532 if (type == *fContext.fVoid_Type) { 533 this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer); 534 } else { 535 ABORT("invalid type: %s", type.description().c_str()); 536 } 537 } 538 fTypeMap[key] = result; 539 return result; 540 } 541 return entry->second; 542 } 543 544 SpvId SPIRVCodeGenerator::getImageType(const Type& type) { 545 ASSERT(type.kind() == Type::kSampler_Kind); 546 this->getType(type); 547 String key = type.name() + to_string((int) fDefaultLayout.fStd); 548 ASSERT(fImageTypeMap.find(key) != fImageTypeMap.end()); 549 return fImageTypeMap[key]; 550 } 551 552 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) { 553 String key = function.fReturnType.description() + "("; 554 String separator; 555 for (size_t i = 0; i < function.fParameters.size(); i++) { 556 key += separator; 557 separator = ", "; 558 key += function.fParameters[i]->fType.description(); 559 } 560 key += ")"; 561 auto entry = fTypeMap.find(key); 562 if (entry == fTypeMap.end()) { 563 SpvId result = this->nextId(); 564 int32_t length = 3 + (int32_t) function.fParameters.size(); 565 SpvId returnType = this->getType(function.fReturnType); 566 std::vector<SpvId> parameterTypes; 567 for (size_t i = 0; i < function.fParameters.size(); i++) { 568 // glslang seems to treat all function arguments as pointers whether they need to be or 569 // not. I was initially puzzled by this until I ran bizarre failures with certain 570 // patterns of function calls and control constructs, as exemplified by this minimal 571 // failure case: 572 // 573 // void sphere(float x) { 574 // } 575 // 576 // void map() { 577 // sphere(1.0); 578 // } 579 // 580 // void main() { 581 // for (int i = 0; i < 1; i++) { 582 // map(); 583 // } 584 // } 585 // 586 // As of this writing, compiling this in the "obvious" way (with sphere taking a float) 587 // crashes. Making it take a float* and storing the argument in a temporary variable, 588 // as glslang does, fixes it. It's entirely possible I simply missed whichever part of 589 // the spec makes this make sense. 590 // if (is_out(function->fParameters[i])) { 591 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType, 592 SpvStorageClassFunction)); 593 // } else { 594 // parameterTypes.push_back(this->getType(function.fParameters[i]->fType)); 595 // } 596 } 597 this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer); 598 this->writeWord(result, fConstantBuffer); 599 this->writeWord(returnType, fConstantBuffer); 600 for (SpvId id : parameterTypes) { 601 this->writeWord(id, fConstantBuffer); 602 } 603 fTypeMap[key] = result; 604 return result; 605 } 606 return entry->second; 607 } 608 609 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) { 610 return this->getPointerType(type, fDefaultLayout, storageClass); 611 } 612 613 SpvId SPIRVCodeGenerator::getPointerType(const Type& rawType, const MemoryLayout& layout, 614 SpvStorageClass_ storageClass) { 615 Type type = this->getActualType(rawType); 616 String key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass); 617 auto entry = fTypeMap.find(key); 618 if (entry == fTypeMap.end()) { 619 SpvId result = this->nextId(); 620 this->writeInstruction(SpvOpTypePointer, result, storageClass, 621 this->getType(type), fConstantBuffer); 622 fTypeMap[key] = result; 623 return result; 624 } 625 return entry->second; 626 } 627 628 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, OutputStream& out) { 629 switch (expr.fKind) { 630 case Expression::kBinary_Kind: 631 return this->writeBinaryExpression((BinaryExpression&) expr, out); 632 case Expression::kBoolLiteral_Kind: 633 return this->writeBoolLiteral((BoolLiteral&) expr); 634 case Expression::kConstructor_Kind: 635 return this->writeConstructor((Constructor&) expr, out); 636 case Expression::kIntLiteral_Kind: 637 return this->writeIntLiteral((IntLiteral&) expr); 638 case Expression::kFieldAccess_Kind: 639 return this->writeFieldAccess(((FieldAccess&) expr), out); 640 case Expression::kFloatLiteral_Kind: 641 return this->writeFloatLiteral(((FloatLiteral&) expr)); 642 case Expression::kFunctionCall_Kind: 643 return this->writeFunctionCall((FunctionCall&) expr, out); 644 case Expression::kPrefix_Kind: 645 return this->writePrefixExpression((PrefixExpression&) expr, out); 646 case Expression::kPostfix_Kind: 647 return this->writePostfixExpression((PostfixExpression&) expr, out); 648 case Expression::kSwizzle_Kind: 649 return this->writeSwizzle((Swizzle&) expr, out); 650 case Expression::kVariableReference_Kind: 651 return this->writeVariableReference((VariableReference&) expr, out); 652 case Expression::kTernary_Kind: 653 return this->writeTernaryExpression((TernaryExpression&) expr, out); 654 case Expression::kIndex_Kind: 655 return this->writeIndexExpression((IndexExpression&) expr, out); 656 default: 657 ABORT("unsupported expression: %s", expr.description().c_str()); 658 } 659 return -1; 660 } 661 662 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, OutputStream& out) { 663 auto intrinsic = fIntrinsicMap.find(c.fFunction.fName); 664 ASSERT(intrinsic != fIntrinsicMap.end()); 665 int32_t intrinsicId; 666 if (c.fArguments.size() > 0) { 667 const Type& type = c.fArguments[0]->fType; 668 if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) { 669 intrinsicId = std::get<1>(intrinsic->second); 670 } else if (is_signed(fContext, type)) { 671 intrinsicId = std::get<2>(intrinsic->second); 672 } else if (is_unsigned(fContext, type)) { 673 intrinsicId = std::get<3>(intrinsic->second); 674 } else if (is_bool(fContext, type)) { 675 intrinsicId = std::get<4>(intrinsic->second); 676 } else { 677 intrinsicId = std::get<1>(intrinsic->second); 678 } 679 } else { 680 intrinsicId = std::get<1>(intrinsic->second); 681 } 682 switch (std::get<0>(intrinsic->second)) { 683 case kGLSL_STD_450_IntrinsicKind: { 684 SpvId result = this->nextId(); 685 std::vector<SpvId> arguments; 686 for (size_t i = 0; i < c.fArguments.size(); i++) { 687 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 688 } 689 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 690 this->writeWord(this->getType(c.fType), out); 691 this->writeWord(result, out); 692 this->writeWord(fGLSLExtendedInstructions, out); 693 this->writeWord(intrinsicId, out); 694 for (SpvId id : arguments) { 695 this->writeWord(id, out); 696 } 697 return result; 698 } 699 case kSPIRV_IntrinsicKind: { 700 SpvId result = this->nextId(); 701 std::vector<SpvId> arguments; 702 for (size_t i = 0; i < c.fArguments.size(); i++) { 703 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 704 } 705 if (c.fType != *fContext.fVoid_Type) { 706 this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out); 707 this->writeWord(this->getType(c.fType), out); 708 this->writeWord(result, out); 709 } else { 710 this->writeOpCode((SpvOp_) intrinsicId, 1 + (int32_t) arguments.size(), out); 711 } 712 for (SpvId id : arguments) { 713 this->writeWord(id, out); 714 } 715 return result; 716 } 717 case kSpecial_IntrinsicKind: 718 return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out); 719 default: 720 ABORT("unsupported intrinsic kind"); 721 } 722 } 723 724 std::vector<SpvId> SPIRVCodeGenerator::vectorize( 725 const std::vector<std::unique_ptr<Expression>>& args, 726 OutputStream& out) { 727 int vectorSize = 0; 728 for (const auto& a : args) { 729 if (a->fType.kind() == Type::kVector_Kind) { 730 if (vectorSize) { 731 ASSERT(a->fType.columns() == vectorSize); 732 } 733 else { 734 vectorSize = a->fType.columns(); 735 } 736 } 737 } 738 std::vector<SpvId> result; 739 for (const auto& a : args) { 740 SpvId raw = this->writeExpression(*a, out); 741 if (vectorSize && a->fType.kind() == Type::kScalar_Kind) { 742 SpvId vector = this->nextId(); 743 this->writeOpCode(SpvOpCompositeConstruct, 3 + vectorSize, out); 744 this->writeWord(this->getType(a->fType.toCompound(fContext, vectorSize, 1)), out); 745 this->writeWord(vector, out); 746 for (int i = 0; i < vectorSize; i++) { 747 this->writeWord(raw, out); 748 } 749 result.push_back(vector); 750 } else { 751 result.push_back(raw); 752 } 753 } 754 return result; 755 } 756 757 void SPIRVCodeGenerator::writeGLSLExtendedInstruction(const Type& type, SpvId id, SpvId floatInst, 758 SpvId signedInst, SpvId unsignedInst, 759 const std::vector<SpvId>& args, 760 OutputStream& out) { 761 this->writeOpCode(SpvOpExtInst, 5 + args.size(), out); 762 this->writeWord(this->getType(type), out); 763 this->writeWord(id, out); 764 this->writeWord(fGLSLExtendedInstructions, out); 765 766 if (is_float(fContext, type)) { 767 this->writeWord(floatInst, out); 768 } else if (is_signed(fContext, type)) { 769 this->writeWord(signedInst, out); 770 } else if (is_unsigned(fContext, type)) { 771 this->writeWord(unsignedInst, out); 772 } else { 773 ASSERT(false); 774 } 775 for (SpvId a : args) { 776 this->writeWord(a, out); 777 } 778 } 779 780 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind, 781 OutputStream& out) { 782 SpvId result = this->nextId(); 783 switch (kind) { 784 case kAtan_SpecialIntrinsic: { 785 std::vector<SpvId> arguments; 786 for (size_t i = 0; i < c.fArguments.size(); i++) { 787 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 788 } 789 this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out); 790 this->writeWord(this->getType(c.fType), out); 791 this->writeWord(result, out); 792 this->writeWord(fGLSLExtendedInstructions, out); 793 this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out); 794 for (SpvId id : arguments) { 795 this->writeWord(id, out); 796 } 797 break; 798 } 799 case kSubpassLoad_SpecialIntrinsic: { 800 SpvId img = this->writeExpression(*c.fArguments[0], out); 801 std::vector<std::unique_ptr<Expression>> args; 802 args.emplace_back(new FloatLiteral(fContext, -1, 0.0)); 803 args.emplace_back(new FloatLiteral(fContext, -1, 0.0)); 804 Constructor ctor(-1, *fContext.fFloat2_Type, std::move(args)); 805 SpvId coords = this->writeConstantVector(ctor); 806 if (1 == c.fArguments.size()) { 807 this->writeInstruction(SpvOpImageRead, 808 this->getType(c.fType), 809 result, 810 img, 811 coords, 812 out); 813 } else { 814 ASSERT(2 == c.fArguments.size()); 815 SpvId sample = this->writeExpression(*c.fArguments[1], out); 816 this->writeInstruction(SpvOpImageRead, 817 this->getType(c.fType), 818 result, 819 img, 820 coords, 821 SpvImageOperandsSampleMask, 822 sample, 823 out); 824 } 825 break; 826 } 827 case kTexelFetch_SpecialIntrinsic: { 828 ASSERT(c.fArguments.size() == 2); 829 SpvId image = this->nextId(); 830 this->writeInstruction(SpvOpImage, 831 this->getImageType(c.fArguments[0]->fType), 832 image, 833 this->writeExpression(*c.fArguments[0], out), 834 out); 835 this->writeInstruction(SpvOpImageFetch, 836 this->getType(c.fType), 837 result, 838 image, 839 this->writeExpression(*c.fArguments[1], out), 840 out); 841 break; 842 } 843 case kTexture_SpecialIntrinsic: { 844 SpvOp_ op = SpvOpImageSampleImplicitLod; 845 switch (c.fArguments[0]->fType.dimensions()) { 846 case SpvDim1D: 847 if (c.fArguments[1]->fType == *fContext.fFloat2_Type) { 848 op = SpvOpImageSampleProjImplicitLod; 849 } else { 850 ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type); 851 } 852 break; 853 case SpvDim2D: 854 if (c.fArguments[1]->fType == *fContext.fFloat3_Type) { 855 op = SpvOpImageSampleProjImplicitLod; 856 } else { 857 ASSERT(c.fArguments[1]->fType == *fContext.fFloat2_Type); 858 } 859 break; 860 case SpvDim3D: 861 if (c.fArguments[1]->fType == *fContext.fFloat4_Type) { 862 op = SpvOpImageSampleProjImplicitLod; 863 } else { 864 ASSERT(c.fArguments[1]->fType == *fContext.fFloat3_Type); 865 } 866 break; 867 case SpvDimCube: // fall through 868 case SpvDimRect: // fall through 869 case SpvDimBuffer: // fall through 870 case SpvDimSubpassData: 871 break; 872 } 873 SpvId type = this->getType(c.fType); 874 SpvId sampler = this->writeExpression(*c.fArguments[0], out); 875 SpvId uv = this->writeExpression(*c.fArguments[1], out); 876 if (c.fArguments.size() == 3) { 877 this->writeInstruction(op, type, result, sampler, uv, 878 SpvImageOperandsBiasMask, 879 this->writeExpression(*c.fArguments[2], out), 880 out); 881 } else { 882 ASSERT(c.fArguments.size() == 2); 883 this->writeInstruction(op, type, result, sampler, uv, 884 out); 885 } 886 break; 887 } 888 case kMod_SpecialIntrinsic: { 889 std::vector<SpvId> args = this->vectorize(c.fArguments, out); 890 ASSERT(args.size() == 2); 891 const Type& operandType = c.fArguments[0]->fType; 892 SpvOp_ op; 893 if (is_float(fContext, operandType)) { 894 op = SpvOpFMod; 895 } else if (is_signed(fContext, operandType)) { 896 op = SpvOpSMod; 897 } else if (is_unsigned(fContext, operandType)) { 898 op = SpvOpUMod; 899 } else { 900 ASSERT(false); 901 return 0; 902 } 903 this->writeOpCode(op, 5, out); 904 this->writeWord(this->getType(operandType), out); 905 this->writeWord(result, out); 906 this->writeWord(args[0], out); 907 this->writeWord(args[1], out); 908 break; 909 } 910 case kClamp_SpecialIntrinsic: { 911 std::vector<SpvId> args = this->vectorize(c.fArguments, out); 912 ASSERT(args.size() == 3); 913 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FClamp, GLSLstd450SClamp, 914 GLSLstd450UClamp, args, out); 915 break; 916 } 917 case kMax_SpecialIntrinsic: { 918 std::vector<SpvId> args = this->vectorize(c.fArguments, out); 919 ASSERT(args.size() == 2); 920 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMax, GLSLstd450SMax, 921 GLSLstd450UMax, args, out); 922 break; 923 } 924 case kMin_SpecialIntrinsic: { 925 std::vector<SpvId> args = this->vectorize(c.fArguments, out); 926 ASSERT(args.size() == 2); 927 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMin, GLSLstd450SMin, 928 GLSLstd450UMin, args, out); 929 break; 930 } 931 case kMix_SpecialIntrinsic: { 932 std::vector<SpvId> args = this->vectorize(c.fArguments, out); 933 ASSERT(args.size() == 3); 934 this->writeGLSLExtendedInstruction(c.fType, result, GLSLstd450FMix, SpvOpUndef, 935 SpvOpUndef, args, out); 936 break; 937 } 938 } 939 return result; 940 } 941 942 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, OutputStream& out) { 943 const auto& entry = fFunctionMap.find(&c.fFunction); 944 if (entry == fFunctionMap.end()) { 945 return this->writeIntrinsicCall(c, out); 946 } 947 // stores (variable, type, lvalue) pairs to extract and save after the function call is complete 948 std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues; 949 std::vector<SpvId> arguments; 950 for (size_t i = 0; i < c.fArguments.size(); i++) { 951 // id of temporary variable that we will use to hold this argument, or 0 if it is being 952 // passed directly 953 SpvId tmpVar; 954 // if we need a temporary var to store this argument, this is the value to store in the var 955 SpvId tmpValueId; 956 if (is_out(*c.fFunction.fParameters[i])) { 957 std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out); 958 SpvId ptr = lv->getPointer(); 959 if (ptr) { 960 arguments.push_back(ptr); 961 continue; 962 } else { 963 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to 964 // copy it into a temp, call the function, read the value out of the temp, and then 965 // update the lvalue. 966 tmpValueId = lv->load(out); 967 tmpVar = this->nextId(); 968 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType), 969 std::move(lv))); 970 } 971 } else { 972 // see getFunctionType for an explanation of why we're always using pointer parameters 973 tmpValueId = this->writeExpression(*c.fArguments[i], out); 974 tmpVar = this->nextId(); 975 } 976 this->writeInstruction(SpvOpVariable, 977 this->getPointerType(c.fArguments[i]->fType, 978 SpvStorageClassFunction), 979 tmpVar, 980 SpvStorageClassFunction, 981 fVariableBuffer); 982 this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out); 983 arguments.push_back(tmpVar); 984 } 985 SpvId result = this->nextId(); 986 this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out); 987 this->writeWord(this->getType(c.fType), out); 988 this->writeWord(result, out); 989 this->writeWord(entry->second, out); 990 for (SpvId id : arguments) { 991 this->writeWord(id, out); 992 } 993 // now that the call is complete, we may need to update some lvalues with the new values of out 994 // arguments 995 for (const auto& tuple : lvalues) { 996 SpvId load = this->nextId(); 997 this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out); 998 std::get<2>(tuple)->store(load, out); 999 } 1000 return result; 1001 } 1002 1003 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) { 1004 ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant()); 1005 SpvId result = this->nextId(); 1006 std::vector<SpvId> arguments; 1007 for (size_t i = 0; i < c.fArguments.size(); i++) { 1008 arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer)); 1009 } 1010 SpvId type = this->getType(c.fType); 1011 if (c.fArguments.size() == 1) { 1012 // with a single argument, a vector will have all of its entries equal to the argument 1013 this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer); 1014 this->writeWord(type, fConstantBuffer); 1015 this->writeWord(result, fConstantBuffer); 1016 for (int i = 0; i < c.fType.columns(); i++) { 1017 this->writeWord(arguments[0], fConstantBuffer); 1018 } 1019 } else { 1020 this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(), 1021 fConstantBuffer); 1022 this->writeWord(type, fConstantBuffer); 1023 this->writeWord(result, fConstantBuffer); 1024 for (SpvId id : arguments) { 1025 this->writeWord(id, fConstantBuffer); 1026 } 1027 } 1028 return result; 1029 } 1030 1031 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, OutputStream& out) { 1032 ASSERT(c.fType.isFloat()); 1033 ASSERT(c.fArguments.size() == 1); 1034 ASSERT(c.fArguments[0]->fType.isNumber()); 1035 SpvId result = this->nextId(); 1036 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1037 if (c.fArguments[0]->fType.isSigned()) { 1038 this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter, 1039 out); 1040 } else { 1041 ASSERT(c.fArguments[0]->fType.isUnsigned()); 1042 this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter, 1043 out); 1044 } 1045 return result; 1046 } 1047 1048 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, OutputStream& out) { 1049 ASSERT(c.fType.isSigned()); 1050 ASSERT(c.fArguments.size() == 1); 1051 ASSERT(c.fArguments[0]->fType.isNumber()); 1052 SpvId result = this->nextId(); 1053 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1054 if (c.fArguments[0]->fType.isFloat()) { 1055 this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter, 1056 out); 1057 } 1058 else { 1059 ASSERT(c.fArguments[0]->fType.isUnsigned()); 1060 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 1061 out); 1062 } 1063 return result; 1064 } 1065 1066 SpvId SPIRVCodeGenerator::writeUIntConstructor(const Constructor& c, OutputStream& out) { 1067 ASSERT(c.fType.isUnsigned()); 1068 ASSERT(c.fArguments.size() == 1); 1069 ASSERT(c.fArguments[0]->fType.isNumber()); 1070 SpvId result = this->nextId(); 1071 SpvId parameter = this->writeExpression(*c.fArguments[0], out); 1072 if (c.fArguments[0]->fType.isFloat()) { 1073 this->writeInstruction(SpvOpConvertFToU, this->getType(c.fType), result, parameter, 1074 out); 1075 } else { 1076 ASSERT(c.fArguments[0]->fType.isSigned()); 1077 this->writeInstruction(SpvOpBitcast, this->getType(c.fType), result, parameter, 1078 out); 1079 } 1080 return result; 1081 } 1082 1083 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type, 1084 OutputStream& out) { 1085 FloatLiteral zero(fContext, -1, 0); 1086 SpvId zeroId = this->writeFloatLiteral(zero); 1087 std::vector<SpvId> columnIds; 1088 for (int column = 0; column < type.columns(); column++) { 1089 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(), 1090 out); 1091 this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)), 1092 out); 1093 SpvId columnId = this->nextId(); 1094 this->writeWord(columnId, out); 1095 columnIds.push_back(columnId); 1096 for (int row = 0; row < type.columns(); row++) { 1097 this->writeWord(row == column ? diagonal : zeroId, out); 1098 } 1099 } 1100 this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(), 1101 out); 1102 this->writeWord(this->getType(type), out); 1103 this->writeWord(id, out); 1104 for (SpvId id : columnIds) { 1105 this->writeWord(id, out); 1106 } 1107 } 1108 1109 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType, 1110 const Type& dstType, OutputStream& out) { 1111 ASSERT(srcType.kind() == Type::kMatrix_Kind); 1112 ASSERT(dstType.kind() == Type::kMatrix_Kind); 1113 ASSERT(srcType.componentType() == dstType.componentType()); 1114 SpvId srcColumnType = this->getType(srcType.componentType().toCompound(fContext, 1115 srcType.rows(), 1116 1)); 1117 SpvId dstColumnType = this->getType(dstType.componentType().toCompound(fContext, 1118 dstType.rows(), 1119 1)); 1120 SpvId zeroId; 1121 if (dstType.componentType() == *fContext.fFloat_Type) { 1122 FloatLiteral zero(fContext, -1, 0.0); 1123 zeroId = this->writeFloatLiteral(zero); 1124 } else if (dstType.componentType() == *fContext.fInt_Type) { 1125 IntLiteral zero(fContext, -1, 0); 1126 zeroId = this->writeIntLiteral(zero); 1127 } else { 1128 ABORT("unsupported matrix component type"); 1129 } 1130 SpvId zeroColumn = 0; 1131 SpvId columns[4]; 1132 for (int i = 0; i < dstType.columns(); i++) { 1133 if (i < srcType.columns()) { 1134 // we're still inside the src matrix, copy the column 1135 SpvId srcColumn = this->nextId(); 1136 this->writeInstruction(SpvOpCompositeExtract, srcColumnType, srcColumn, src, i, out); 1137 SpvId dstColumn; 1138 if (srcType.rows() == dstType.rows()) { 1139 // columns are equal size, don't need to do anything 1140 dstColumn = srcColumn; 1141 } 1142 else if (dstType.rows() > srcType.rows()) { 1143 // dst column is bigger, need to zero-pad it 1144 dstColumn = this->nextId(); 1145 int delta = dstType.rows() - srcType.rows(); 1146 this->writeOpCode(SpvOpCompositeConstruct, 4 + delta, out); 1147 this->writeWord(dstColumnType, out); 1148 this->writeWord(dstColumn, out); 1149 this->writeWord(srcColumn, out); 1150 for (int i = 0; i < delta; ++i) { 1151 this->writeWord(zeroId, out); 1152 } 1153 } 1154 else { 1155 // dst column is smaller, need to swizzle the src column 1156 dstColumn = this->nextId(); 1157 int count = dstType.rows(); 1158 this->writeOpCode(SpvOpVectorShuffle, 5 + count, out); 1159 this->writeWord(dstColumnType, out); 1160 this->writeWord(dstColumn, out); 1161 this->writeWord(srcColumn, out); 1162 this->writeWord(srcColumn, out); 1163 for (int i = 0; i < count; i++) { 1164 this->writeWord(i, out); 1165 } 1166 } 1167 columns[i] = dstColumn; 1168 } else { 1169 // we're past the end of the src matrix, need a vector of zeroes 1170 if (!zeroColumn) { 1171 zeroColumn = this->nextId(); 1172 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.rows(), out); 1173 this->writeWord(dstColumnType, out); 1174 this->writeWord(zeroColumn, out); 1175 for (int i = 0; i < dstType.rows(); ++i) { 1176 this->writeWord(zeroId, out); 1177 } 1178 } 1179 columns[i] = zeroColumn; 1180 } 1181 } 1182 this->writeOpCode(SpvOpCompositeConstruct, 3 + dstType.columns(), out); 1183 this->writeWord(this->getType(dstType), out); 1184 this->writeWord(id, out); 1185 for (int i = 0; i < dstType.columns(); i++) { 1186 this->writeWord(columns[i], out); 1187 } 1188 } 1189 1190 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, OutputStream& out) { 1191 ASSERT(c.fType.kind() == Type::kMatrix_Kind); 1192 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1193 // an instruction 1194 std::vector<SpvId> arguments; 1195 for (size_t i = 0; i < c.fArguments.size(); i++) { 1196 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1197 } 1198 SpvId result = this->nextId(); 1199 int rows = c.fType.rows(); 1200 int columns = c.fType.columns(); 1201 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1202 this->writeUniformScaleMatrix(result, arguments[0], c.fType, out); 1203 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) { 1204 this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out); 1205 } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kVector_Kind) { 1206 ASSERT(c.fType.rows() == 2 && c.fType.columns() == 2); 1207 ASSERT(c.fArguments[0]->fType.columns() == 4); 1208 SpvId componentType = this->getType(c.fType.componentType()); 1209 SpvId v[4]; 1210 for (int i = 0; i < 4; ++i) { 1211 v[i] = this->nextId(); 1212 this->writeInstruction(SpvOpCompositeExtract, componentType, v[i], arguments[0], i, out); 1213 } 1214 SpvId columnType = this->getType(c.fType.componentType().toCompound(fContext, 2, 1)); 1215 SpvId column1 = this->nextId(); 1216 this->writeInstruction(SpvOpCompositeConstruct, columnType, column1, v[0], v[1], out); 1217 SpvId column2 = this->nextId(); 1218 this->writeInstruction(SpvOpCompositeConstruct, columnType, column2, v[2], v[3], out); 1219 this->writeInstruction(SpvOpCompositeConstruct, this->getType(c.fType), result, column1, 1220 column2, out); 1221 } else { 1222 std::vector<SpvId> columnIds; 1223 // ids of vectors and scalars we have written to the current column so far 1224 std::vector<SpvId> currentColumn; 1225 // the total number of scalars represented by currentColumn's entries 1226 int currentCount = 0; 1227 for (size_t i = 0; i < arguments.size(); i++) { 1228 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind && 1229 c.fArguments[i]->fType.columns() == c.fType.rows()) { 1230 // this is a complete column by itself 1231 ASSERT(currentCount == 0); 1232 columnIds.push_back(arguments[i]); 1233 } else { 1234 currentColumn.push_back(arguments[i]); 1235 currentCount += c.fArguments[i]->fType.columns(); 1236 if (currentCount == rows) { 1237 currentCount = 0; 1238 this->writeOpCode(SpvOpCompositeConstruct, 3 + currentColumn.size(), out); 1239 this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows, 1240 1)), 1241 out); 1242 SpvId columnId = this->nextId(); 1243 this->writeWord(columnId, out); 1244 columnIds.push_back(columnId); 1245 for (SpvId id : currentColumn) { 1246 this->writeWord(id, out); 1247 } 1248 currentColumn.clear(); 1249 } 1250 ASSERT(currentCount < rows); 1251 } 1252 } 1253 ASSERT(columnIds.size() == (size_t) columns); 1254 this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out); 1255 this->writeWord(this->getType(c.fType), out); 1256 this->writeWord(result, out); 1257 for (SpvId id : columnIds) { 1258 this->writeWord(id, out); 1259 } 1260 } 1261 return result; 1262 } 1263 1264 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, OutputStream& out) { 1265 ASSERT(c.fType.kind() == Type::kVector_Kind); 1266 if (c.isConstant()) { 1267 return this->writeConstantVector(c); 1268 } 1269 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1270 // an instruction 1271 std::vector<SpvId> arguments; 1272 for (size_t i = 0; i < c.fArguments.size(); i++) { 1273 if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) { 1274 // SPIR-V doesn't support vector(vector-of-different-type) directly, so we need to 1275 // extract the components and convert them in that case manually. On top of that, 1276 // as of this writing there's a bug in the Intel Vulkan driver where OpCreateComposite 1277 // doesn't handle vector arguments at all, so we always extract vector components and 1278 // pass them into OpCreateComposite individually. 1279 SpvId vec = this->writeExpression(*c.fArguments[i], out); 1280 SpvOp_ op = SpvOpUndef; 1281 const Type& src = c.fArguments[i]->fType.componentType(); 1282 const Type& dst = c.fType.componentType(); 1283 if (dst == *fContext.fFloat_Type || dst == *fContext.fHalf_Type) { 1284 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1285 if (c.fArguments.size() == 1) { 1286 return vec; 1287 } 1288 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1289 op = SpvOpConvertSToF; 1290 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1291 op = SpvOpConvertUToF; 1292 } else { 1293 ASSERT(false); 1294 } 1295 } else if (dst == *fContext.fInt_Type || dst == *fContext.fShort_Type) { 1296 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1297 op = SpvOpConvertFToS; 1298 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1299 if (c.fArguments.size() == 1) { 1300 return vec; 1301 } 1302 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1303 op = SpvOpBitcast; 1304 } else { 1305 ASSERT(false); 1306 } 1307 } else if (dst == *fContext.fUInt_Type || dst == *fContext.fUShort_Type) { 1308 if (src == *fContext.fFloat_Type || src == *fContext.fHalf_Type) { 1309 op = SpvOpConvertFToS; 1310 } else if (src == *fContext.fInt_Type || src == *fContext.fShort_Type) { 1311 op = SpvOpBitcast; 1312 } else if (src == *fContext.fUInt_Type || src == *fContext.fUShort_Type) { 1313 if (c.fArguments.size() == 1) { 1314 return vec; 1315 } 1316 } else { 1317 ASSERT(false); 1318 } 1319 } 1320 for (int j = 0; j < c.fArguments[i]->fType.columns(); j++) { 1321 SpvId swizzle = this->nextId(); 1322 this->writeInstruction(SpvOpCompositeExtract, this->getType(src), swizzle, vec, j, 1323 out); 1324 if (op != SpvOpUndef) { 1325 SpvId cast = this->nextId(); 1326 this->writeInstruction(op, this->getType(dst), cast, swizzle, out); 1327 arguments.push_back(cast); 1328 } else { 1329 arguments.push_back(swizzle); 1330 } 1331 } 1332 } else { 1333 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1334 } 1335 } 1336 SpvId result = this->nextId(); 1337 if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) { 1338 this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out); 1339 this->writeWord(this->getType(c.fType), out); 1340 this->writeWord(result, out); 1341 for (int i = 0; i < c.fType.columns(); i++) { 1342 this->writeWord(arguments[0], out); 1343 } 1344 } else { 1345 ASSERT(arguments.size() > 1); 1346 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) arguments.size(), out); 1347 this->writeWord(this->getType(c.fType), out); 1348 this->writeWord(result, out); 1349 for (SpvId id : arguments) { 1350 this->writeWord(id, out); 1351 } 1352 } 1353 return result; 1354 } 1355 1356 SpvId SPIRVCodeGenerator::writeArrayConstructor(const Constructor& c, OutputStream& out) { 1357 ASSERT(c.fType.kind() == Type::kArray_Kind); 1358 // go ahead and write the arguments so we don't try to write new instructions in the middle of 1359 // an instruction 1360 std::vector<SpvId> arguments; 1361 for (size_t i = 0; i < c.fArguments.size(); i++) { 1362 arguments.push_back(this->writeExpression(*c.fArguments[i], out)); 1363 } 1364 SpvId result = this->nextId(); 1365 this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out); 1366 this->writeWord(this->getType(c.fType), out); 1367 this->writeWord(result, out); 1368 for (SpvId id : arguments) { 1369 this->writeWord(id, out); 1370 } 1371 return result; 1372 } 1373 1374 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, OutputStream& out) { 1375 if (c.fArguments.size() == 1 && 1376 this->getActualType(c.fType) == this->getActualType(c.fArguments[0]->fType)) { 1377 return this->writeExpression(*c.fArguments[0], out); 1378 } 1379 if (c.fType == *fContext.fFloat_Type || c.fType == *fContext.fHalf_Type) { 1380 return this->writeFloatConstructor(c, out); 1381 } else if (c.fType == *fContext.fInt_Type || c.fType == *fContext.fShort_Type) { 1382 return this->writeIntConstructor(c, out); 1383 } else if (c.fType == *fContext.fUInt_Type || c.fType == *fContext.fUShort_Type) { 1384 return this->writeUIntConstructor(c, out); 1385 } 1386 switch (c.fType.kind()) { 1387 case Type::kVector_Kind: 1388 return this->writeVectorConstructor(c, out); 1389 case Type::kMatrix_Kind: 1390 return this->writeMatrixConstructor(c, out); 1391 case Type::kArray_Kind: 1392 return this->writeArrayConstructor(c, out); 1393 default: 1394 ABORT("unsupported constructor: %s", c.description().c_str()); 1395 } 1396 } 1397 1398 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) { 1399 if (modifiers.fFlags & Modifiers::kIn_Flag) { 1400 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1401 return SpvStorageClassInput; 1402 } else if (modifiers.fFlags & Modifiers::kOut_Flag) { 1403 ASSERT(!(modifiers.fLayout.fFlags & Layout::kPushConstant_Flag)); 1404 return SpvStorageClassOutput; 1405 } else if (modifiers.fFlags & Modifiers::kUniform_Flag) { 1406 if (modifiers.fLayout.fFlags & Layout::kPushConstant_Flag) { 1407 return SpvStorageClassPushConstant; 1408 } 1409 return SpvStorageClassUniform; 1410 } else { 1411 return SpvStorageClassFunction; 1412 } 1413 } 1414 1415 SpvStorageClass_ get_storage_class(const Expression& expr) { 1416 switch (expr.fKind) { 1417 case Expression::kVariableReference_Kind: { 1418 const Variable& var = ((VariableReference&) expr).fVariable; 1419 if (var.fStorage != Variable::kGlobal_Storage) { 1420 return SpvStorageClassFunction; 1421 } 1422 SpvStorageClass_ result = get_storage_class(var.fModifiers); 1423 if (result == SpvStorageClassFunction) { 1424 result = SpvStorageClassPrivate; 1425 } 1426 return result; 1427 } 1428 case Expression::kFieldAccess_Kind: 1429 return get_storage_class(*((FieldAccess&) expr).fBase); 1430 case Expression::kIndex_Kind: 1431 return get_storage_class(*((IndexExpression&) expr).fBase); 1432 default: 1433 return SpvStorageClassFunction; 1434 } 1435 } 1436 1437 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, OutputStream& out) { 1438 std::vector<SpvId> chain; 1439 switch (expr.fKind) { 1440 case Expression::kIndex_Kind: { 1441 IndexExpression& indexExpr = (IndexExpression&) expr; 1442 chain = this->getAccessChain(*indexExpr.fBase, out); 1443 chain.push_back(this->writeExpression(*indexExpr.fIndex, out)); 1444 break; 1445 } 1446 case Expression::kFieldAccess_Kind: { 1447 FieldAccess& fieldExpr = (FieldAccess&) expr; 1448 chain = this->getAccessChain(*fieldExpr.fBase, out); 1449 IntLiteral index(fContext, -1, fieldExpr.fFieldIndex); 1450 chain.push_back(this->writeIntLiteral(index)); 1451 break; 1452 } 1453 default: 1454 chain.push_back(this->getLValue(expr, out)->getPointer()); 1455 } 1456 return chain; 1457 } 1458 1459 class PointerLValue : public SPIRVCodeGenerator::LValue { 1460 public: 1461 PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type) 1462 : fGen(gen) 1463 , fPointer(pointer) 1464 , fType(type) {} 1465 1466 virtual SpvId getPointer() override { 1467 return fPointer; 1468 } 1469 1470 virtual SpvId load(OutputStream& out) override { 1471 SpvId result = fGen.nextId(); 1472 fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out); 1473 return result; 1474 } 1475 1476 virtual void store(SpvId value, OutputStream& out) override { 1477 fGen.writeInstruction(SpvOpStore, fPointer, value, out); 1478 } 1479 1480 private: 1481 SPIRVCodeGenerator& fGen; 1482 const SpvId fPointer; 1483 const SpvId fType; 1484 }; 1485 1486 class SwizzleLValue : public SPIRVCodeGenerator::LValue { 1487 public: 1488 SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components, 1489 const Type& baseType, const Type& swizzleType) 1490 : fGen(gen) 1491 , fVecPointer(vecPointer) 1492 , fComponents(components) 1493 , fBaseType(baseType) 1494 , fSwizzleType(swizzleType) {} 1495 1496 virtual SpvId getPointer() override { 1497 return 0; 1498 } 1499 1500 virtual SpvId load(OutputStream& out) override { 1501 SpvId base = fGen.nextId(); 1502 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1503 SpvId result = fGen.nextId(); 1504 fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out); 1505 fGen.writeWord(fGen.getType(fSwizzleType), out); 1506 fGen.writeWord(result, out); 1507 fGen.writeWord(base, out); 1508 fGen.writeWord(base, out); 1509 for (int component : fComponents) { 1510 fGen.writeWord(component, out); 1511 } 1512 return result; 1513 } 1514 1515 virtual void store(SpvId value, OutputStream& out) override { 1516 // use OpVectorShuffle to mix and match the vector components. We effectively create 1517 // a virtual vector out of the concatenation of the left and right vectors, and then 1518 // select components from this virtual vector to make the result vector. For 1519 // instance, given: 1520 // float3L = ...; 1521 // float3R = ...; 1522 // L.xz = R.xy; 1523 // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want 1524 // our result vector to look like (R.x, L.y, R.y), so we need to select indices 1525 // (3, 1, 4). 1526 SpvId base = fGen.nextId(); 1527 fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out); 1528 SpvId shuffle = fGen.nextId(); 1529 fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out); 1530 fGen.writeWord(fGen.getType(fBaseType), out); 1531 fGen.writeWord(shuffle, out); 1532 fGen.writeWord(base, out); 1533 fGen.writeWord(value, out); 1534 for (int i = 0; i < fBaseType.columns(); i++) { 1535 // current offset into the virtual vector, defaults to pulling the unmodified 1536 // value from the left side 1537 int offset = i; 1538 // check to see if we are writing this component 1539 for (size_t j = 0; j < fComponents.size(); j++) { 1540 if (fComponents[j] == i) { 1541 // we're writing to this component, so adjust the offset to pull from 1542 // the correct component of the right side instead of preserving the 1543 // value from the left 1544 offset = (int) (j + fBaseType.columns()); 1545 break; 1546 } 1547 } 1548 fGen.writeWord(offset, out); 1549 } 1550 fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out); 1551 } 1552 1553 private: 1554 SPIRVCodeGenerator& fGen; 1555 const SpvId fVecPointer; 1556 const std::vector<int>& fComponents; 1557 const Type& fBaseType; 1558 const Type& fSwizzleType; 1559 }; 1560 1561 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr, 1562 OutputStream& out) { 1563 switch (expr.fKind) { 1564 case Expression::kVariableReference_Kind: { 1565 const Variable& var = ((VariableReference&) expr).fVariable; 1566 auto entry = fVariableMap.find(&var); 1567 ASSERT(entry != fVariableMap.end()); 1568 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1569 *this, 1570 entry->second, 1571 this->getType(expr.fType))); 1572 } 1573 case Expression::kIndex_Kind: // fall through 1574 case Expression::kFieldAccess_Kind: { 1575 std::vector<SpvId> chain = this->getAccessChain(expr, out); 1576 SpvId member = this->nextId(); 1577 this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out); 1578 this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out); 1579 this->writeWord(member, out); 1580 for (SpvId idx : chain) { 1581 this->writeWord(idx, out); 1582 } 1583 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1584 *this, 1585 member, 1586 this->getType(expr.fType))); 1587 } 1588 case Expression::kSwizzle_Kind: { 1589 Swizzle& swizzle = (Swizzle&) expr; 1590 size_t count = swizzle.fComponents.size(); 1591 SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer(); 1592 ASSERT(base); 1593 if (count == 1) { 1594 IntLiteral index(fContext, -1, swizzle.fComponents[0]); 1595 SpvId member = this->nextId(); 1596 this->writeInstruction(SpvOpAccessChain, 1597 this->getPointerType(swizzle.fType, 1598 get_storage_class(*swizzle.fBase)), 1599 member, 1600 base, 1601 this->writeIntLiteral(index), 1602 out); 1603 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1604 *this, 1605 member, 1606 this->getType(expr.fType))); 1607 } else { 1608 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue( 1609 *this, 1610 base, 1611 swizzle.fComponents, 1612 swizzle.fBase->fType, 1613 expr.fType)); 1614 } 1615 } 1616 case Expression::kTernary_Kind: { 1617 TernaryExpression& t = (TernaryExpression&) expr; 1618 SpvId test = this->writeExpression(*t.fTest, out); 1619 SpvId end = this->nextId(); 1620 SpvId ifTrueLabel = this->nextId(); 1621 SpvId ifFalseLabel = this->nextId(); 1622 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 1623 this->writeInstruction(SpvOpBranchConditional, test, ifTrueLabel, ifFalseLabel, out); 1624 this->writeLabel(ifTrueLabel, out); 1625 SpvId ifTrue = this->getLValue(*t.fIfTrue, out)->getPointer(); 1626 ASSERT(ifTrue); 1627 this->writeInstruction(SpvOpBranch, end, out); 1628 ifTrueLabel = fCurrentBlock; 1629 SpvId ifFalse = this->getLValue(*t.fIfFalse, out)->getPointer(); 1630 ASSERT(ifFalse); 1631 ifFalseLabel = fCurrentBlock; 1632 this->writeInstruction(SpvOpBranch, end, out); 1633 SpvId result = this->nextId(); 1634 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, ifTrue, 1635 ifTrueLabel, ifFalse, ifFalseLabel, out); 1636 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1637 *this, 1638 result, 1639 this->getType(expr.fType))); 1640 } 1641 default: 1642 // expr isn't actually an lvalue, create a dummy variable for it. This case happens due 1643 // to the need to store values in temporary variables during function calls (see 1644 // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been 1645 // caught by IRGenerator 1646 SpvId result = this->nextId(); 1647 SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction); 1648 this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction, 1649 fVariableBuffer); 1650 this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out); 1651 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue( 1652 *this, 1653 result, 1654 this->getType(expr.fType))); 1655 } 1656 } 1657 1658 SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, OutputStream& out) { 1659 SpvId result = this->nextId(); 1660 auto entry = fVariableMap.find(&ref.fVariable); 1661 ASSERT(entry != fVariableMap.end()); 1662 SpvId var = entry->second; 1663 this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out); 1664 if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN && 1665 fProgram.fSettings.fFlipY) { 1666 // need to remap to a top-left coordinate system 1667 if (fRTHeightStructId == (SpvId) -1) { 1668 // height variable hasn't been written yet 1669 std::shared_ptr<SymbolTable> st(new SymbolTable(&fErrors)); 1670 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 1671 std::vector<Type::Field> fields; 1672 fields.emplace_back(Modifiers(), SKSL_RTHEIGHT_NAME, fContext.fFloat_Type.get()); 1673 StringFragment name("sksl_synthetic_uniforms"); 1674 Type intfStruct(-1, name, fields); 1675 Layout layout(0, -1, -1, 1, -1, -1, -1, -1, Layout::Format::kUnspecified, 1676 Layout::kUnspecified_Primitive, -1, -1, "", Layout::kNo_Key, 1677 StringFragment()); 1678 Variable* intfVar = new Variable(-1, 1679 Modifiers(layout, Modifiers::kUniform_Flag), 1680 name, 1681 intfStruct, 1682 Variable::kGlobal_Storage); 1683 fSynthetics.takeOwnership(intfVar); 1684 InterfaceBlock intf(-1, intfVar, name, String(""), 1685 std::vector<std::unique_ptr<Expression>>(), st); 1686 fRTHeightStructId = this->writeInterfaceBlock(intf); 1687 fRTHeightFieldIndex = 0; 1688 } 1689 ASSERT(fRTHeightFieldIndex != (SpvId) -1); 1690 // write float4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0) 1691 SpvId xId = this->nextId(); 1692 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId, 1693 result, 0, out); 1694 IntLiteral fieldIndex(fContext, -1, fRTHeightFieldIndex); 1695 SpvId fieldIndexId = this->writeIntLiteral(fieldIndex); 1696 SpvId heightPtr = this->nextId(); 1697 this->writeOpCode(SpvOpAccessChain, 5, out); 1698 this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out); 1699 this->writeWord(heightPtr, out); 1700 this->writeWord(fRTHeightStructId, out); 1701 this->writeWord(fieldIndexId, out); 1702 SpvId heightRead = this->nextId(); 1703 this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead, 1704 heightPtr, out); 1705 SpvId rawYId = this->nextId(); 1706 this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId, 1707 result, 1, out); 1708 SpvId flippedYId = this->nextId(); 1709 this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId, 1710 heightRead, rawYId, out); 1711 FloatLiteral zero(fContext, -1, 0.0); 1712 SpvId zeroId = writeFloatLiteral(zero); 1713 FloatLiteral one(fContext, -1, 1.0); 1714 SpvId oneId = writeFloatLiteral(one); 1715 SpvId flipped = this->nextId(); 1716 this->writeOpCode(SpvOpCompositeConstruct, 7, out); 1717 this->writeWord(this->getType(*fContext.fFloat4_Type), out); 1718 this->writeWord(flipped, out); 1719 this->writeWord(xId, out); 1720 this->writeWord(flippedYId, out); 1721 this->writeWord(zeroId, out); 1722 this->writeWord(oneId, out); 1723 return flipped; 1724 } 1725 return result; 1726 } 1727 1728 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, OutputStream& out) { 1729 return getLValue(expr, out)->load(out); 1730 } 1731 1732 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, OutputStream& out) { 1733 return getLValue(f, out)->load(out); 1734 } 1735 1736 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, OutputStream& out) { 1737 SpvId base = this->writeExpression(*swizzle.fBase, out); 1738 SpvId result = this->nextId(); 1739 size_t count = swizzle.fComponents.size(); 1740 if (count == 1) { 1741 this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base, 1742 swizzle.fComponents[0], out); 1743 } else { 1744 this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out); 1745 this->writeWord(this->getType(swizzle.fType), out); 1746 this->writeWord(result, out); 1747 this->writeWord(base, out); 1748 this->writeWord(base, out); 1749 for (int component : swizzle.fComponents) { 1750 this->writeWord(component, out); 1751 } 1752 } 1753 return result; 1754 } 1755 1756 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType, 1757 const Type& operandType, SpvId lhs, 1758 SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt, 1759 SpvOp_ ifUInt, SpvOp_ ifBool, OutputStream& out) { 1760 SpvId result = this->nextId(); 1761 if (is_float(fContext, operandType)) { 1762 this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out); 1763 } else if (is_signed(fContext, operandType)) { 1764 this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out); 1765 } else if (is_unsigned(fContext, operandType)) { 1766 this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out); 1767 } else if (operandType == *fContext.fBool_Type) { 1768 this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out); 1769 } else { 1770 ABORT("invalid operandType: %s", operandType.description().c_str()); 1771 } 1772 return result; 1773 } 1774 1775 bool is_assignment(Token::Kind op) { 1776 switch (op) { 1777 case Token::EQ: // fall through 1778 case Token::PLUSEQ: // fall through 1779 case Token::MINUSEQ: // fall through 1780 case Token::STAREQ: // fall through 1781 case Token::SLASHEQ: // fall through 1782 case Token::PERCENTEQ: // fall through 1783 case Token::SHLEQ: // fall through 1784 case Token::SHREQ: // fall through 1785 case Token::BITWISEOREQ: // fall through 1786 case Token::BITWISEXOREQ: // fall through 1787 case Token::BITWISEANDEQ: // fall through 1788 case Token::LOGICALOREQ: // fall through 1789 case Token::LOGICALXOREQ: // fall through 1790 case Token::LOGICALANDEQ: 1791 return true; 1792 default: 1793 return false; 1794 } 1795 } 1796 1797 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, OutputStream& out) { 1798 if (operandType.kind() == Type::kVector_Kind) { 1799 SpvId result = this->nextId(); 1800 this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out); 1801 return result; 1802 } 1803 return id; 1804 } 1805 1806 SpvId SPIRVCodeGenerator::writeMatrixComparison(const Type& operandType, SpvId lhs, SpvId rhs, 1807 SpvOp_ floatOperator, SpvOp_ intOperator, 1808 OutputStream& out) { 1809 SpvOp_ compareOp = is_float(fContext, operandType) ? floatOperator : intOperator; 1810 ASSERT(operandType.kind() == Type::kMatrix_Kind); 1811 SpvId rowType = this->getType(operandType.componentType().toCompound(fContext, 1812 operandType.columns(), 1813 1)); 1814 SpvId bvecType = this->getType(fContext.fBool_Type->toCompound(fContext, 1815 operandType.columns(), 1816 1)); 1817 SpvId boolType = this->getType(*fContext.fBool_Type); 1818 SpvId result = 0; 1819 for (int i = 0; i < operandType.rows(); i++) { 1820 SpvId rowL = this->nextId(); 1821 this->writeInstruction(SpvOpCompositeExtract, rowType, rowL, lhs, 0, out); 1822 SpvId rowR = this->nextId(); 1823 this->writeInstruction(SpvOpCompositeExtract, rowType, rowR, rhs, 0, out); 1824 SpvId compare = this->nextId(); 1825 this->writeInstruction(compareOp, bvecType, compare, rowL, rowR, out); 1826 SpvId all = this->nextId(); 1827 this->writeInstruction(SpvOpAll, boolType, all, compare, out); 1828 if (result != 0) { 1829 SpvId next = this->nextId(); 1830 this->writeInstruction(SpvOpLogicalAnd, boolType, next, result, all, out); 1831 result = next; 1832 } 1833 else { 1834 result = all; 1835 } 1836 } 1837 return result; 1838 } 1839 1840 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, OutputStream& out) { 1841 // handle cases where we don't necessarily evaluate both LHS and RHS 1842 switch (b.fOperator) { 1843 case Token::EQ: { 1844 SpvId rhs = this->writeExpression(*b.fRight, out); 1845 this->getLValue(*b.fLeft, out)->store(rhs, out); 1846 return rhs; 1847 } 1848 case Token::LOGICALAND: 1849 return this->writeLogicalAnd(b, out); 1850 case Token::LOGICALOR: 1851 return this->writeLogicalOr(b, out); 1852 default: 1853 break; 1854 } 1855 1856 // "normal" operators 1857 const Type& resultType = b.fType; 1858 std::unique_ptr<LValue> lvalue; 1859 SpvId lhs; 1860 if (is_assignment(b.fOperator)) { 1861 lvalue = this->getLValue(*b.fLeft, out); 1862 lhs = lvalue->load(out); 1863 } else { 1864 lvalue = nullptr; 1865 lhs = this->writeExpression(*b.fLeft, out); 1866 } 1867 SpvId rhs = this->writeExpression(*b.fRight, out); 1868 if (b.fOperator == Token::COMMA) { 1869 return rhs; 1870 } 1871 Type tmp("<invalid>"); 1872 // component type we are operating on: float, int, uint 1873 const Type* operandType; 1874 // IR allows mismatched types in expressions (e.g. float2* float), but they need special handling 1875 // in SPIR-V 1876 if (this->getActualType(b.fLeft->fType) != this->getActualType(b.fRight->fType)) { 1877 if (b.fLeft->fType.kind() == Type::kVector_Kind && 1878 b.fRight->fType.isNumber()) { 1879 // promote number to vector 1880 SpvId vec = this->nextId(); 1881 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 1882 this->writeWord(this->getType(resultType), out); 1883 this->writeWord(vec, out); 1884 for (int i = 0; i < resultType.columns(); i++) { 1885 this->writeWord(rhs, out); 1886 } 1887 rhs = vec; 1888 operandType = &b.fRight->fType; 1889 } else if (b.fRight->fType.kind() == Type::kVector_Kind && 1890 b.fLeft->fType.isNumber()) { 1891 // promote number to vector 1892 SpvId vec = this->nextId(); 1893 this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out); 1894 this->writeWord(this->getType(resultType), out); 1895 this->writeWord(vec, out); 1896 for (int i = 0; i < resultType.columns(); i++) { 1897 this->writeWord(lhs, out); 1898 } 1899 lhs = vec; 1900 ASSERT(!lvalue); 1901 operandType = &b.fLeft->fType; 1902 } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) { 1903 SpvOp_ op; 1904 if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 1905 op = SpvOpMatrixTimesMatrix; 1906 } else if (b.fRight->fType.kind() == Type::kVector_Kind) { 1907 op = SpvOpMatrixTimesVector; 1908 } else { 1909 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind); 1910 op = SpvOpMatrixTimesScalar; 1911 } 1912 SpvId result = this->nextId(); 1913 this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out); 1914 if (b.fOperator == Token::STAREQ) { 1915 lvalue->store(result, out); 1916 } else { 1917 ASSERT(b.fOperator == Token::STAR); 1918 } 1919 return result; 1920 } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) { 1921 SpvId result = this->nextId(); 1922 if (b.fLeft->fType.kind() == Type::kVector_Kind) { 1923 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result, 1924 lhs, rhs, out); 1925 } else { 1926 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind); 1927 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs, 1928 lhs, out); 1929 } 1930 if (b.fOperator == Token::STAREQ) { 1931 lvalue->store(result, out); 1932 } else { 1933 ASSERT(b.fOperator == Token::STAR); 1934 } 1935 return result; 1936 } else { 1937 ABORT("unsupported binary expression: %s", b.description().c_str()); 1938 } 1939 } else { 1940 tmp = this->getActualType(b.fLeft->fType); 1941 operandType = &tmp; 1942 ASSERT(*operandType == this->getActualType(b.fRight->fType)); 1943 } 1944 switch (b.fOperator) { 1945 case Token::EQEQ: { 1946 if (operandType->kind() == Type::kMatrix_Kind) { 1947 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdEqual, 1948 SpvOpIEqual, out); 1949 } 1950 ASSERT(resultType == *fContext.fBool_Type); 1951 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1952 SpvOpFOrdEqual, SpvOpIEqual, 1953 SpvOpIEqual, SpvOpLogicalEqual, out), 1954 *operandType, out); 1955 } 1956 case Token::NEQ: 1957 if (operandType->kind() == Type::kMatrix_Kind) { 1958 return this->writeMatrixComparison(*operandType, lhs, rhs, SpvOpFOrdNotEqual, 1959 SpvOpINotEqual, out); 1960 } 1961 ASSERT(resultType == *fContext.fBool_Type); 1962 return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1963 SpvOpFOrdNotEqual, SpvOpINotEqual, 1964 SpvOpINotEqual, SpvOpLogicalNotEqual, 1965 out), 1966 *operandType, out); 1967 case Token::GT: 1968 ASSERT(resultType == *fContext.fBool_Type); 1969 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1970 SpvOpFOrdGreaterThan, SpvOpSGreaterThan, 1971 SpvOpUGreaterThan, SpvOpUndef, out); 1972 case Token::LT: 1973 ASSERT(resultType == *fContext.fBool_Type); 1974 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan, 1975 SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out); 1976 case Token::GTEQ: 1977 ASSERT(resultType == *fContext.fBool_Type); 1978 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1979 SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual, 1980 SpvOpUGreaterThanEqual, SpvOpUndef, out); 1981 case Token::LTEQ: 1982 ASSERT(resultType == *fContext.fBool_Type); 1983 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 1984 SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual, 1985 SpvOpULessThanEqual, SpvOpUndef, out); 1986 case Token::PLUS: 1987 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 1988 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 1989 case Token::MINUS: 1990 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 1991 SpvOpISub, SpvOpISub, SpvOpUndef, out); 1992 case Token::STAR: 1993 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 1994 b.fRight->fType.kind() == Type::kMatrix_Kind) { 1995 // matrix multiply 1996 SpvId result = this->nextId(); 1997 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 1998 lhs, rhs, out); 1999 return result; 2000 } 2001 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 2002 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 2003 case Token::SLASH: 2004 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 2005 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 2006 case Token::PERCENT: 2007 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2008 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2009 case Token::SHL: 2010 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2011 SpvOpShiftLeftLogical, SpvOpShiftLeftLogical, 2012 SpvOpUndef, out); 2013 case Token::SHR: 2014 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2015 SpvOpShiftRightArithmetic, SpvOpShiftRightLogical, 2016 SpvOpUndef, out); 2017 case Token::BITWISEAND: 2018 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2019 SpvOpBitwiseAnd, SpvOpBitwiseAnd, SpvOpUndef, out); 2020 case Token::BITWISEOR: 2021 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2022 SpvOpBitwiseOr, SpvOpBitwiseOr, SpvOpUndef, out); 2023 case Token::BITWISEXOR: 2024 return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpUndef, 2025 SpvOpBitwiseXor, SpvOpBitwiseXor, SpvOpUndef, out); 2026 case Token::PLUSEQ: { 2027 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd, 2028 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2029 ASSERT(lvalue); 2030 lvalue->store(result, out); 2031 return result; 2032 } 2033 case Token::MINUSEQ: { 2034 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub, 2035 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2036 ASSERT(lvalue); 2037 lvalue->store(result, out); 2038 return result; 2039 } 2040 case Token::STAREQ: { 2041 if (b.fLeft->fType.kind() == Type::kMatrix_Kind && 2042 b.fRight->fType.kind() == Type::kMatrix_Kind) { 2043 // matrix multiply 2044 SpvId result = this->nextId(); 2045 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result, 2046 lhs, rhs, out); 2047 ASSERT(lvalue); 2048 lvalue->store(result, out); 2049 return result; 2050 } 2051 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul, 2052 SpvOpIMul, SpvOpIMul, SpvOpUndef, out); 2053 ASSERT(lvalue); 2054 lvalue->store(result, out); 2055 return result; 2056 } 2057 case Token::SLASHEQ: { 2058 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv, 2059 SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out); 2060 ASSERT(lvalue); 2061 lvalue->store(result, out); 2062 return result; 2063 } 2064 case Token::PERCENTEQ: { 2065 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMod, 2066 SpvOpSMod, SpvOpUMod, SpvOpUndef, out); 2067 ASSERT(lvalue); 2068 lvalue->store(result, out); 2069 return result; 2070 } 2071 case Token::SHLEQ: { 2072 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2073 SpvOpUndef, SpvOpShiftLeftLogical, 2074 SpvOpShiftLeftLogical, SpvOpUndef, out); 2075 ASSERT(lvalue); 2076 lvalue->store(result, out); 2077 return result; 2078 } 2079 case Token::SHREQ: { 2080 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2081 SpvOpUndef, SpvOpShiftRightArithmetic, 2082 SpvOpShiftRightLogical, SpvOpUndef, out); 2083 ASSERT(lvalue); 2084 lvalue->store(result, out); 2085 return result; 2086 } 2087 case Token::BITWISEANDEQ: { 2088 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2089 SpvOpUndef, SpvOpBitwiseAnd, SpvOpBitwiseAnd, 2090 SpvOpUndef, out); 2091 ASSERT(lvalue); 2092 lvalue->store(result, out); 2093 return result; 2094 } 2095 case Token::BITWISEOREQ: { 2096 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2097 SpvOpUndef, SpvOpBitwiseOr, SpvOpBitwiseOr, 2098 SpvOpUndef, out); 2099 ASSERT(lvalue); 2100 lvalue->store(result, out); 2101 return result; 2102 } 2103 case Token::BITWISEXOREQ: { 2104 SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, 2105 SpvOpUndef, SpvOpBitwiseXor, SpvOpBitwiseXor, 2106 SpvOpUndef, out); 2107 ASSERT(lvalue); 2108 lvalue->store(result, out); 2109 return result; 2110 } 2111 default: 2112 ABORT("unsupported binary expression: %s", b.description().c_str()); 2113 } 2114 } 2115 2116 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, OutputStream& out) { 2117 ASSERT(a.fOperator == Token::LOGICALAND); 2118 BoolLiteral falseLiteral(fContext, -1, false); 2119 SpvId falseConstant = this->writeBoolLiteral(falseLiteral); 2120 SpvId lhs = this->writeExpression(*a.fLeft, out); 2121 SpvId rhsLabel = this->nextId(); 2122 SpvId end = this->nextId(); 2123 SpvId lhsBlock = fCurrentBlock; 2124 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2125 this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out); 2126 this->writeLabel(rhsLabel, out); 2127 SpvId rhs = this->writeExpression(*a.fRight, out); 2128 SpvId rhsBlock = fCurrentBlock; 2129 this->writeInstruction(SpvOpBranch, end, out); 2130 this->writeLabel(end, out); 2131 SpvId result = this->nextId(); 2132 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant, 2133 lhsBlock, rhs, rhsBlock, out); 2134 return result; 2135 } 2136 2137 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, OutputStream& out) { 2138 ASSERT(o.fOperator == Token::LOGICALOR); 2139 BoolLiteral trueLiteral(fContext, -1, true); 2140 SpvId trueConstant = this->writeBoolLiteral(trueLiteral); 2141 SpvId lhs = this->writeExpression(*o.fLeft, out); 2142 SpvId rhsLabel = this->nextId(); 2143 SpvId end = this->nextId(); 2144 SpvId lhsBlock = fCurrentBlock; 2145 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2146 this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out); 2147 this->writeLabel(rhsLabel, out); 2148 SpvId rhs = this->writeExpression(*o.fRight, out); 2149 SpvId rhsBlock = fCurrentBlock; 2150 this->writeInstruction(SpvOpBranch, end, out); 2151 this->writeLabel(end, out); 2152 SpvId result = this->nextId(); 2153 this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant, 2154 lhsBlock, rhs, rhsBlock, out); 2155 return result; 2156 } 2157 2158 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, OutputStream& out) { 2159 SpvId test = this->writeExpression(*t.fTest, out); 2160 if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) { 2161 // both true and false are constants, can just use OpSelect 2162 SpvId result = this->nextId(); 2163 SpvId trueId = this->writeExpression(*t.fIfTrue, out); 2164 SpvId falseId = this->writeExpression(*t.fIfFalse, out); 2165 this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId, 2166 out); 2167 return result; 2168 } 2169 // was originally using OpPhi to choose the result, but for some reason that is crashing on 2170 // Adreno. Switched to storing the result in a temp variable as glslang does. 2171 SpvId var = this->nextId(); 2172 this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction), 2173 var, SpvStorageClassFunction, fVariableBuffer); 2174 SpvId trueLabel = this->nextId(); 2175 SpvId falseLabel = this->nextId(); 2176 SpvId end = this->nextId(); 2177 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2178 this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out); 2179 this->writeLabel(trueLabel, out); 2180 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out); 2181 this->writeInstruction(SpvOpBranch, end, out); 2182 this->writeLabel(falseLabel, out); 2183 this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out); 2184 this->writeInstruction(SpvOpBranch, end, out); 2185 this->writeLabel(end, out); 2186 SpvId result = this->nextId(); 2187 this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out); 2188 return result; 2189 } 2190 2191 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) { 2192 if (type.isInteger()) { 2193 return std::unique_ptr<Expression>(new IntLiteral(context, -1, 1, &type)); 2194 } 2195 else if (type.isFloat()) { 2196 return std::unique_ptr<Expression>(new FloatLiteral(context, -1, 1.0, &type)); 2197 } else { 2198 ABORT("math is unsupported on type '%s'", type.name().c_str()); 2199 } 2200 } 2201 2202 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, OutputStream& out) { 2203 if (p.fOperator == Token::MINUS) { 2204 SpvId result = this->nextId(); 2205 SpvId typeId = this->getType(p.fType); 2206 SpvId expr = this->writeExpression(*p.fOperand, out); 2207 if (is_float(fContext, p.fType)) { 2208 this->writeInstruction(SpvOpFNegate, typeId, result, expr, out); 2209 } else if (is_signed(fContext, p.fType)) { 2210 this->writeInstruction(SpvOpSNegate, typeId, result, expr, out); 2211 } else { 2212 ABORT("unsupported prefix expression %s", p.description().c_str()); 2213 }; 2214 return result; 2215 } 2216 switch (p.fOperator) { 2217 case Token::PLUS: 2218 return this->writeExpression(*p.fOperand, out); 2219 case Token::PLUSPLUS: { 2220 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2221 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2222 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2223 SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef, 2224 out); 2225 lv->store(result, out); 2226 return result; 2227 } 2228 case Token::MINUSMINUS: { 2229 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2230 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2231 SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one, 2232 SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef, 2233 out); 2234 lv->store(result, out); 2235 return result; 2236 } 2237 case Token::LOGICALNOT: { 2238 ASSERT(p.fOperand->fType == *fContext.fBool_Type); 2239 SpvId result = this->nextId(); 2240 this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result, 2241 this->writeExpression(*p.fOperand, out), out); 2242 return result; 2243 } 2244 case Token::BITWISENOT: { 2245 SpvId result = this->nextId(); 2246 this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result, 2247 this->writeExpression(*p.fOperand, out), out); 2248 return result; 2249 } 2250 default: 2251 ABORT("unsupported prefix expression: %s", p.description().c_str()); 2252 } 2253 } 2254 2255 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, OutputStream& out) { 2256 std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out); 2257 SpvId result = lv->load(out); 2258 SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out); 2259 switch (p.fOperator) { 2260 case Token::PLUSPLUS: { 2261 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd, 2262 SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out); 2263 lv->store(temp, out); 2264 return result; 2265 } 2266 case Token::MINUSMINUS: { 2267 SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub, 2268 SpvOpISub, SpvOpISub, SpvOpUndef, out); 2269 lv->store(temp, out); 2270 return result; 2271 } 2272 default: 2273 ABORT("unsupported postfix expression %s", p.description().c_str()); 2274 } 2275 } 2276 2277 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) { 2278 if (b.fValue) { 2279 if (fBoolTrue == 0) { 2280 fBoolTrue = this->nextId(); 2281 this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue, 2282 fConstantBuffer); 2283 } 2284 return fBoolTrue; 2285 } else { 2286 if (fBoolFalse == 0) { 2287 fBoolFalse = this->nextId(); 2288 this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse, 2289 fConstantBuffer); 2290 } 2291 return fBoolFalse; 2292 } 2293 } 2294 2295 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) { 2296 if (i.fType == *fContext.fInt_Type) { 2297 auto entry = fIntConstants.find(i.fValue); 2298 if (entry == fIntConstants.end()) { 2299 SpvId result = this->nextId(); 2300 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2301 fConstantBuffer); 2302 fIntConstants[i.fValue] = result; 2303 return result; 2304 } 2305 return entry->second; 2306 } else { 2307 ASSERT(i.fType == *fContext.fUInt_Type); 2308 auto entry = fUIntConstants.find(i.fValue); 2309 if (entry == fUIntConstants.end()) { 2310 SpvId result = this->nextId(); 2311 this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue, 2312 fConstantBuffer); 2313 fUIntConstants[i.fValue] = result; 2314 return result; 2315 } 2316 return entry->second; 2317 } 2318 } 2319 2320 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) { 2321 if (f.fType == *fContext.fFloat_Type || f.fType == *fContext.fHalf_Type) { 2322 float value = (float) f.fValue; 2323 auto entry = fFloatConstants.find(value); 2324 if (entry == fFloatConstants.end()) { 2325 SpvId result = this->nextId(); 2326 uint32_t bits; 2327 ASSERT(sizeof(bits) == sizeof(value)); 2328 memcpy(&bits, &value, sizeof(bits)); 2329 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits, 2330 fConstantBuffer); 2331 fFloatConstants[value] = result; 2332 return result; 2333 } 2334 return entry->second; 2335 } else { 2336 ASSERT(f.fType == *fContext.fDouble_Type); 2337 auto entry = fDoubleConstants.find(f.fValue); 2338 if (entry == fDoubleConstants.end()) { 2339 SpvId result = this->nextId(); 2340 uint64_t bits; 2341 ASSERT(sizeof(bits) == sizeof(f.fValue)); 2342 memcpy(&bits, &f.fValue, sizeof(bits)); 2343 this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, 2344 bits & 0xffffffff, bits >> 32, fConstantBuffer); 2345 fDoubleConstants[f.fValue] = result; 2346 return result; 2347 } 2348 return entry->second; 2349 } 2350 } 2351 2352 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, OutputStream& out) { 2353 SpvId result = fFunctionMap[&f]; 2354 this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result, 2355 SpvFunctionControlMaskNone, this->getFunctionType(f), out); 2356 this->writeInstruction(SpvOpName, result, f.fName, fNameBuffer); 2357 for (size_t i = 0; i < f.fParameters.size(); i++) { 2358 SpvId id = this->nextId(); 2359 fVariableMap[f.fParameters[i]] = id; 2360 SpvId type; 2361 type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction); 2362 this->writeInstruction(SpvOpFunctionParameter, type, id, out); 2363 } 2364 return result; 2365 } 2366 2367 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, OutputStream& out) { 2368 fVariableBuffer.reset(); 2369 SpvId result = this->writeFunctionStart(f.fDeclaration, out); 2370 this->writeLabel(this->nextId(), out); 2371 if (f.fDeclaration.fName == "main") { 2372 write_stringstream(fGlobalInitializersBuffer, out); 2373 } 2374 StringStream bodyBuffer; 2375 this->writeBlock((Block&) *f.fBody, bodyBuffer); 2376 write_stringstream(fVariableBuffer, out); 2377 write_stringstream(bodyBuffer, out); 2378 if (fCurrentBlock) { 2379 if (f.fDeclaration.fReturnType == *fContext.fVoid_Type) { 2380 this->writeInstruction(SpvOpReturn, out); 2381 } else { 2382 this->writeInstruction(SpvOpUnreachable, out); 2383 } 2384 } 2385 this->writeInstruction(SpvOpFunctionEnd, out); 2386 return result; 2387 } 2388 2389 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) { 2390 if (layout.fLocation >= 0) { 2391 this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation, 2392 fDecorationBuffer); 2393 } 2394 if (layout.fBinding >= 0) { 2395 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding, 2396 fDecorationBuffer); 2397 } 2398 if (layout.fIndex >= 0) { 2399 this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex, 2400 fDecorationBuffer); 2401 } 2402 if (layout.fSet >= 0) { 2403 this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet, 2404 fDecorationBuffer); 2405 } 2406 if (layout.fInputAttachmentIndex >= 0) { 2407 this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex, 2408 layout.fInputAttachmentIndex, fDecorationBuffer); 2409 fCapabilities |= (((uint64_t) 1) << SpvCapabilityInputAttachment); 2410 } 2411 if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN && 2412 layout.fBuiltin != SK_IN_BUILTIN) { 2413 this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin, 2414 fDecorationBuffer); 2415 } 2416 } 2417 2418 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) { 2419 if (layout.fLocation >= 0) { 2420 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation, 2421 layout.fLocation, fDecorationBuffer); 2422 } 2423 if (layout.fBinding >= 0) { 2424 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding, 2425 layout.fBinding, fDecorationBuffer); 2426 } 2427 if (layout.fIndex >= 0) { 2428 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex, 2429 layout.fIndex, fDecorationBuffer); 2430 } 2431 if (layout.fSet >= 0) { 2432 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet, 2433 layout.fSet, fDecorationBuffer); 2434 } 2435 if (layout.fInputAttachmentIndex >= 0) { 2436 this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex, 2437 layout.fInputAttachmentIndex, fDecorationBuffer); 2438 } 2439 if (layout.fBuiltin >= 0) { 2440 this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn, 2441 layout.fBuiltin, fDecorationBuffer); 2442 } 2443 } 2444 2445 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) { 2446 bool isBuffer = (0 != (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag)); 2447 bool pushConstant = (0 != (intf.fVariable.fModifiers.fLayout.fFlags & 2448 Layout::kPushConstant_Flag)); 2449 MemoryLayout memoryLayout = (pushConstant || isBuffer) ? 2450 MemoryLayout(MemoryLayout::k430_Standard) : 2451 fDefaultLayout; 2452 SpvId result = this->nextId(); 2453 const Type* type = &intf.fVariable.fType; 2454 if (fProgram.fInputs.fRTHeight) { 2455 ASSERT(fRTHeightStructId == (SpvId) -1); 2456 ASSERT(fRTHeightFieldIndex == (SpvId) -1); 2457 std::vector<Type::Field> fields = type->fields(); 2458 fRTHeightStructId = result; 2459 fRTHeightFieldIndex = fields.size(); 2460 fields.emplace_back(Modifiers(), StringFragment(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get()); 2461 type = new Type(type->fOffset, type->name(), fields); 2462 } 2463 SpvId typeId = this->getType(*type, memoryLayout); 2464 if (intf.fVariable.fModifiers.fFlags & Modifiers::kBuffer_Flag) { 2465 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBufferBlock, fDecorationBuffer); 2466 } else { 2467 this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer); 2468 } 2469 SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers); 2470 SpvId ptrType = this->nextId(); 2471 this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer); 2472 this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer); 2473 Layout layout = intf.fVariable.fModifiers.fLayout; 2474 if (intf.fVariable.fModifiers.fFlags & Modifiers::kUniform_Flag && layout.fSet == -1) { 2475 layout.fSet = 0; 2476 } 2477 this->writeLayout(layout, result); 2478 fVariableMap[&intf.fVariable] = result; 2479 if (fProgram.fInputs.fRTHeight) { 2480 delete type; 2481 } 2482 return result; 2483 } 2484 2485 void SPIRVCodeGenerator::writePrecisionModifier(const Modifiers& modifiers, SpvId id) { 2486 if ((modifiers.fFlags & Modifiers::kLowp_Flag) | 2487 (modifiers.fFlags & Modifiers::kMediump_Flag)) { 2488 this->writeInstruction(SpvOpDecorate, id, SpvDecorationRelaxedPrecision, fDecorationBuffer); 2489 } 2490 } 2491 2492 #define BUILTIN_IGNORE 9999 2493 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl, 2494 OutputStream& out) { 2495 for (size_t i = 0; i < decl.fVars.size(); i++) { 2496 if (decl.fVars[i]->fKind == Statement::kNop_Kind) { 2497 continue; 2498 } 2499 const VarDeclaration& varDecl = (VarDeclaration&) *decl.fVars[i]; 2500 const Variable* var = varDecl.fVar; 2501 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2502 // in the OpenGL backend. 2503 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2504 Modifiers::kWriteOnly_Flag | 2505 Modifiers::kCoherent_Flag | 2506 Modifiers::kVolatile_Flag | 2507 Modifiers::kRestrict_Flag))); 2508 if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) { 2509 continue; 2510 } 2511 if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN && 2512 kind != Program::kFragment_Kind) { 2513 ASSERT(!fProgram.fSettings.fFragColorIsInOut); 2514 continue; 2515 } 2516 if (!var->fReadCount && !var->fWriteCount && 2517 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag | 2518 Modifiers::kOut_Flag | 2519 Modifiers::kUniform_Flag | 2520 Modifiers::kBuffer_Flag))) { 2521 // variable is dead and not an input / output var (the Vulkan debug layers complain if 2522 // we elide an interface var, even if it's dead) 2523 continue; 2524 } 2525 SpvStorageClass_ storageClass; 2526 if (var->fModifiers.fFlags & Modifiers::kIn_Flag) { 2527 storageClass = SpvStorageClassInput; 2528 } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) { 2529 storageClass = SpvStorageClassOutput; 2530 } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) { 2531 if (var->fType.kind() == Type::kSampler_Kind) { 2532 storageClass = SpvStorageClassUniformConstant; 2533 } else { 2534 storageClass = SpvStorageClassUniform; 2535 } 2536 } else { 2537 storageClass = SpvStorageClassPrivate; 2538 } 2539 SpvId id = this->nextId(); 2540 fVariableMap[var] = id; 2541 SpvId type = this->getPointerType(var->fType, storageClass); 2542 this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer); 2543 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer); 2544 this->writePrecisionModifier(var->fModifiers, id); 2545 if (varDecl.fValue) { 2546 ASSERT(!fCurrentBlock); 2547 fCurrentBlock = -1; 2548 SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer); 2549 this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer); 2550 fCurrentBlock = 0; 2551 } 2552 this->writeLayout(var->fModifiers.fLayout, id); 2553 if (var->fModifiers.fFlags & Modifiers::kFlat_Flag) { 2554 this->writeInstruction(SpvOpDecorate, id, SpvDecorationFlat, fDecorationBuffer); 2555 } 2556 if (var->fModifiers.fFlags & Modifiers::kNoPerspective_Flag) { 2557 this->writeInstruction(SpvOpDecorate, id, SpvDecorationNoPerspective, 2558 fDecorationBuffer); 2559 } 2560 } 2561 } 2562 2563 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, OutputStream& out) { 2564 for (const auto& stmt : decl.fVars) { 2565 ASSERT(stmt->fKind == Statement::kVarDeclaration_Kind); 2566 VarDeclaration& varDecl = (VarDeclaration&) *stmt; 2567 const Variable* var = varDecl.fVar; 2568 // These haven't been implemented in our SPIR-V generator yet and we only currently use them 2569 // in the OpenGL backend. 2570 ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag | 2571 Modifiers::kWriteOnly_Flag | 2572 Modifiers::kCoherent_Flag | 2573 Modifiers::kVolatile_Flag | 2574 Modifiers::kRestrict_Flag))); 2575 SpvId id = this->nextId(); 2576 fVariableMap[var] = id; 2577 SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction); 2578 this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer); 2579 this->writeInstruction(SpvOpName, id, var->fName, fNameBuffer); 2580 if (varDecl.fValue) { 2581 SpvId value = this->writeExpression(*varDecl.fValue, out); 2582 this->writeInstruction(SpvOpStore, id, value, out); 2583 } 2584 } 2585 } 2586 2587 void SPIRVCodeGenerator::writeStatement(const Statement& s, OutputStream& out) { 2588 switch (s.fKind) { 2589 case Statement::kNop_Kind: 2590 break; 2591 case Statement::kBlock_Kind: 2592 this->writeBlock((Block&) s, out); 2593 break; 2594 case Statement::kExpression_Kind: 2595 this->writeExpression(*((ExpressionStatement&) s).fExpression, out); 2596 break; 2597 case Statement::kReturn_Kind: 2598 this->writeReturnStatement((ReturnStatement&) s, out); 2599 break; 2600 case Statement::kVarDeclarations_Kind: 2601 this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out); 2602 break; 2603 case Statement::kIf_Kind: 2604 this->writeIfStatement((IfStatement&) s, out); 2605 break; 2606 case Statement::kFor_Kind: 2607 this->writeForStatement((ForStatement&) s, out); 2608 break; 2609 case Statement::kWhile_Kind: 2610 this->writeWhileStatement((WhileStatement&) s, out); 2611 break; 2612 case Statement::kDo_Kind: 2613 this->writeDoStatement((DoStatement&) s, out); 2614 break; 2615 case Statement::kSwitch_Kind: 2616 this->writeSwitchStatement((SwitchStatement&) s, out); 2617 break; 2618 case Statement::kBreak_Kind: 2619 this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out); 2620 break; 2621 case Statement::kContinue_Kind: 2622 this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out); 2623 break; 2624 case Statement::kDiscard_Kind: 2625 this->writeInstruction(SpvOpKill, out); 2626 break; 2627 default: 2628 ABORT("unsupported statement: %s", s.description().c_str()); 2629 } 2630 } 2631 2632 void SPIRVCodeGenerator::writeBlock(const Block& b, OutputStream& out) { 2633 for (size_t i = 0; i < b.fStatements.size(); i++) { 2634 this->writeStatement(*b.fStatements[i], out); 2635 } 2636 } 2637 2638 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, OutputStream& out) { 2639 SpvId test = this->writeExpression(*stmt.fTest, out); 2640 SpvId ifTrue = this->nextId(); 2641 SpvId ifFalse = this->nextId(); 2642 if (stmt.fIfFalse) { 2643 SpvId end = this->nextId(); 2644 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2645 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2646 this->writeLabel(ifTrue, out); 2647 this->writeStatement(*stmt.fIfTrue, out); 2648 if (fCurrentBlock) { 2649 this->writeInstruction(SpvOpBranch, end, out); 2650 } 2651 this->writeLabel(ifFalse, out); 2652 this->writeStatement(*stmt.fIfFalse, out); 2653 if (fCurrentBlock) { 2654 this->writeInstruction(SpvOpBranch, end, out); 2655 } 2656 this->writeLabel(end, out); 2657 } else { 2658 this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out); 2659 this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out); 2660 this->writeLabel(ifTrue, out); 2661 this->writeStatement(*stmt.fIfTrue, out); 2662 if (fCurrentBlock) { 2663 this->writeInstruction(SpvOpBranch, ifFalse, out); 2664 } 2665 this->writeLabel(ifFalse, out); 2666 } 2667 } 2668 2669 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, OutputStream& out) { 2670 if (f.fInitializer) { 2671 this->writeStatement(*f.fInitializer, out); 2672 } 2673 SpvId header = this->nextId(); 2674 SpvId start = this->nextId(); 2675 SpvId body = this->nextId(); 2676 SpvId next = this->nextId(); 2677 fContinueTarget.push(next); 2678 SpvId end = this->nextId(); 2679 fBreakTarget.push(end); 2680 this->writeInstruction(SpvOpBranch, header, out); 2681 this->writeLabel(header, out); 2682 this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out); 2683 this->writeInstruction(SpvOpBranch, start, out); 2684 this->writeLabel(start, out); 2685 if (f.fTest) { 2686 SpvId test = this->writeExpression(*f.fTest, out); 2687 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2688 } 2689 this->writeLabel(body, out); 2690 this->writeStatement(*f.fStatement, out); 2691 if (fCurrentBlock) { 2692 this->writeInstruction(SpvOpBranch, next, out); 2693 } 2694 this->writeLabel(next, out); 2695 if (f.fNext) { 2696 this->writeExpression(*f.fNext, out); 2697 } 2698 this->writeInstruction(SpvOpBranch, header, out); 2699 this->writeLabel(end, out); 2700 fBreakTarget.pop(); 2701 fContinueTarget.pop(); 2702 } 2703 2704 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, OutputStream& out) { 2705 // We believe the while loop code below will work, but Skia doesn't actually use them and 2706 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2707 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2708 // message, simply remove the error call below to see whether our while loop support actually 2709 // works. 2710 fErrors.error(w.fOffset, "internal error: while loop support has been disabled in SPIR-V, " 2711 "see SkSLSPIRVCodeGenerator.cpp for details"); 2712 2713 SpvId header = this->nextId(); 2714 SpvId start = this->nextId(); 2715 SpvId body = this->nextId(); 2716 fContinueTarget.push(start); 2717 SpvId end = this->nextId(); 2718 fBreakTarget.push(end); 2719 this->writeInstruction(SpvOpBranch, header, out); 2720 this->writeLabel(header, out); 2721 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2722 this->writeInstruction(SpvOpBranch, start, out); 2723 this->writeLabel(start, out); 2724 SpvId test = this->writeExpression(*w.fTest, out); 2725 this->writeInstruction(SpvOpBranchConditional, test, body, end, out); 2726 this->writeLabel(body, out); 2727 this->writeStatement(*w.fStatement, out); 2728 if (fCurrentBlock) { 2729 this->writeInstruction(SpvOpBranch, start, out); 2730 } 2731 this->writeLabel(end, out); 2732 fBreakTarget.pop(); 2733 fContinueTarget.pop(); 2734 } 2735 2736 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, OutputStream& out) { 2737 // We believe the do loop code below will work, but Skia doesn't actually use them and 2738 // adequately testing this code in the absence of Skia exercising it isn't straightforward. For 2739 // the time being, we just fail with an error due to the lack of testing. If you encounter this 2740 // message, simply remove the error call below to see whether our do loop support actually 2741 // works. 2742 fErrors.error(d.fOffset, "internal error: do loop support has been disabled in SPIR-V, see " 2743 "SkSLSPIRVCodeGenerator.cpp for details"); 2744 2745 SpvId header = this->nextId(); 2746 SpvId start = this->nextId(); 2747 SpvId next = this->nextId(); 2748 fContinueTarget.push(next); 2749 SpvId end = this->nextId(); 2750 fBreakTarget.push(end); 2751 this->writeInstruction(SpvOpBranch, header, out); 2752 this->writeLabel(header, out); 2753 this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out); 2754 this->writeInstruction(SpvOpBranch, start, out); 2755 this->writeLabel(start, out); 2756 this->writeStatement(*d.fStatement, out); 2757 if (fCurrentBlock) { 2758 this->writeInstruction(SpvOpBranch, next, out); 2759 } 2760 this->writeLabel(next, out); 2761 SpvId test = this->writeExpression(*d.fTest, out); 2762 this->writeInstruction(SpvOpBranchConditional, test, start, end, out); 2763 this->writeLabel(end, out); 2764 fBreakTarget.pop(); 2765 fContinueTarget.pop(); 2766 } 2767 2768 void SPIRVCodeGenerator::writeSwitchStatement(const SwitchStatement& s, OutputStream& out) { 2769 SpvId value = this->writeExpression(*s.fValue, out); 2770 std::vector<SpvId> labels; 2771 SpvId end = this->nextId(); 2772 SpvId defaultLabel = end; 2773 fBreakTarget.push(end); 2774 int size = 3; 2775 for (const auto& c : s.fCases) { 2776 SpvId label = this->nextId(); 2777 labels.push_back(label); 2778 if (c->fValue) { 2779 size += 2; 2780 } else { 2781 defaultLabel = label; 2782 } 2783 } 2784 labels.push_back(end); 2785 this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out); 2786 this->writeOpCode(SpvOpSwitch, size, out); 2787 this->writeWord(value, out); 2788 this->writeWord(defaultLabel, out); 2789 for (size_t i = 0; i < s.fCases.size(); ++i) { 2790 if (!s.fCases[i]->fValue) { 2791 continue; 2792 } 2793 ASSERT(s.fCases[i]->fValue->fKind == Expression::kIntLiteral_Kind); 2794 this->writeWord(((IntLiteral&) *s.fCases[i]->fValue).fValue, out); 2795 this->writeWord(labels[i], out); 2796 } 2797 for (size_t i = 0; i < s.fCases.size(); ++i) { 2798 this->writeLabel(labels[i], out); 2799 for (const auto& stmt : s.fCases[i]->fStatements) { 2800 this->writeStatement(*stmt, out); 2801 } 2802 if (fCurrentBlock) { 2803 this->writeInstruction(SpvOpBranch, labels[i + 1], out); 2804 } 2805 } 2806 this->writeLabel(end, out); 2807 fBreakTarget.pop(); 2808 } 2809 2810 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, OutputStream& out) { 2811 if (r.fExpression) { 2812 this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out), 2813 out); 2814 } else { 2815 this->writeInstruction(SpvOpReturn, out); 2816 } 2817 } 2818 2819 void SPIRVCodeGenerator::writeGeometryShaderExecutionMode(SpvId entryPoint, OutputStream& out) { 2820 ASSERT(fProgram.fKind == Program::kGeometry_Kind); 2821 int invocations = 1; 2822 for (size_t i = 0; i < fProgram.fElements.size(); i++) { 2823 if (fProgram.fElements[i]->fKind == ProgramElement::kModifiers_Kind) { 2824 const Modifiers& m = ((ModifiersDeclaration&) *fProgram.fElements[i]).fModifiers; 2825 if (m.fFlags & Modifiers::kIn_Flag) { 2826 if (m.fLayout.fInvocations != -1) { 2827 invocations = m.fLayout.fInvocations; 2828 } 2829 SpvId input; 2830 switch (m.fLayout.fPrimitive) { 2831 case Layout::kPoints_Primitive: 2832 input = SpvExecutionModeInputPoints; 2833 break; 2834 case Layout::kLines_Primitive: 2835 input = SpvExecutionModeInputLines; 2836 break; 2837 case Layout::kLinesAdjacency_Primitive: 2838 input = SpvExecutionModeInputLinesAdjacency; 2839 break; 2840 case Layout::kTriangles_Primitive: 2841 input = SpvExecutionModeTriangles; 2842 break; 2843 case Layout::kTrianglesAdjacency_Primitive: 2844 input = SpvExecutionModeInputTrianglesAdjacency; 2845 break; 2846 default: 2847 input = 0; 2848 break; 2849 } 2850 if (input) { 2851 this->writeInstruction(SpvOpExecutionMode, entryPoint, input, out); 2852 } 2853 } else if (m.fFlags & Modifiers::kOut_Flag) { 2854 SpvId output; 2855 switch (m.fLayout.fPrimitive) { 2856 case Layout::kPoints_Primitive: 2857 output = SpvExecutionModeOutputPoints; 2858 break; 2859 case Layout::kLineStrip_Primitive: 2860 output = SpvExecutionModeOutputLineStrip; 2861 break; 2862 case Layout::kTriangleStrip_Primitive: 2863 output = SpvExecutionModeOutputTriangleStrip; 2864 break; 2865 default: 2866 output = 0; 2867 break; 2868 } 2869 if (output) { 2870 this->writeInstruction(SpvOpExecutionMode, entryPoint, output, out); 2871 } 2872 if (m.fLayout.fMaxVertices != -1) { 2873 this->writeInstruction(SpvOpExecutionMode, entryPoint, 2874 SpvExecutionModeOutputVertices, m.fLayout.fMaxVertices, 2875 out); 2876 } 2877 } 2878 } 2879 } 2880 this->writeInstruction(SpvOpExecutionMode, entryPoint, SpvExecutionModeInvocations, 2881 invocations, out); 2882 } 2883 2884 void SPIRVCodeGenerator::writeInstructions(const Program& program, OutputStream& out) { 2885 fGLSLExtendedInstructions = this->nextId(); 2886 StringStream body; 2887 std::set<SpvId> interfaceVars; 2888 // assign IDs to functions, determine sk_in size 2889 int skInSize = -1; 2890 for (size_t i = 0; i < program.fElements.size(); i++) { 2891 switch (program.fElements[i]->fKind) { 2892 case ProgramElement::kFunction_Kind: { 2893 FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i]; 2894 fFunctionMap[&f.fDeclaration] = this->nextId(); 2895 break; 2896 } 2897 case ProgramElement::kModifiers_Kind: { 2898 Modifiers& m = ((ModifiersDeclaration&) *program.fElements[i]).fModifiers; 2899 if (m.fFlags & Modifiers::kIn_Flag) { 2900 switch (m.fLayout.fPrimitive) { 2901 case Layout::kPoints_Primitive: // break 2902 case Layout::kLines_Primitive: 2903 skInSize = 1; 2904 break; 2905 case Layout::kLinesAdjacency_Primitive: // break 2906 skInSize = 2; 2907 break; 2908 case Layout::kTriangles_Primitive: // break 2909 case Layout::kTrianglesAdjacency_Primitive: 2910 skInSize = 3; 2911 break; 2912 default: 2913 break; 2914 } 2915 } 2916 break; 2917 } 2918 default: 2919 break; 2920 } 2921 } 2922 for (size_t i = 0; i < program.fElements.size(); i++) { 2923 if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) { 2924 InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i]; 2925 if (SK_IN_BUILTIN == intf.fVariable.fModifiers.fLayout.fBuiltin) { 2926 ASSERT(skInSize != -1); 2927 intf.fSizes.emplace_back(new IntLiteral(fContext, -1, skInSize)); 2928 } 2929 SpvId id = this->writeInterfaceBlock(intf); 2930 if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) || 2931 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) { 2932 interfaceVars.insert(id); 2933 } 2934 } 2935 } 2936 for (size_t i = 0; i < program.fElements.size(); i++) { 2937 if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) { 2938 this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]), 2939 body); 2940 } 2941 } 2942 for (size_t i = 0; i < program.fElements.size(); i++) { 2943 if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) { 2944 this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body); 2945 } 2946 } 2947 const FunctionDeclaration* main = nullptr; 2948 for (auto entry : fFunctionMap) { 2949 if (entry.first->fName == "main") { 2950 main = entry.first; 2951 } 2952 } 2953 ASSERT(main); 2954 for (auto entry : fVariableMap) { 2955 const Variable* var = entry.first; 2956 if (var->fStorage == Variable::kGlobal_Storage && 2957 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) || 2958 (var->fModifiers.fFlags & Modifiers::kOut_Flag))) { 2959 interfaceVars.insert(entry.second); 2960 } 2961 } 2962 this->writeCapabilities(out); 2963 this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out); 2964 this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out); 2965 this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (main->fName.fLength + 4) / 4) + 2966 (int32_t) interfaceVars.size(), out); 2967 switch (program.fKind) { 2968 case Program::kVertex_Kind: 2969 this->writeWord(SpvExecutionModelVertex, out); 2970 break; 2971 case Program::kFragment_Kind: 2972 this->writeWord(SpvExecutionModelFragment, out); 2973 break; 2974 case Program::kGeometry_Kind: 2975 this->writeWord(SpvExecutionModelGeometry, out); 2976 break; 2977 default: 2978 ABORT("cannot write this kind of program to SPIR-V\n"); 2979 } 2980 SpvId entryPoint = fFunctionMap[main]; 2981 this->writeWord(entryPoint, out); 2982 this->writeString(main->fName.fChars, main->fName.fLength, out); 2983 for (int var : interfaceVars) { 2984 this->writeWord(var, out); 2985 } 2986 if (program.fKind == Program::kGeometry_Kind) { 2987 this->writeGeometryShaderExecutionMode(entryPoint, out); 2988 } 2989 if (program.fKind == Program::kFragment_Kind) { 2990 this->writeInstruction(SpvOpExecutionMode, 2991 fFunctionMap[main], 2992 SpvExecutionModeOriginUpperLeft, 2993 out); 2994 } 2995 for (size_t i = 0; i < program.fElements.size(); i++) { 2996 if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) { 2997 this->writeInstruction(SpvOpSourceExtension, 2998 ((Extension&) *program.fElements[i]).fName.c_str(), 2999 out); 3000 } 3001 } 3002 3003 write_stringstream(fExtraGlobalsBuffer, out); 3004 write_stringstream(fNameBuffer, out); 3005 write_stringstream(fDecorationBuffer, out); 3006 write_stringstream(fConstantBuffer, out); 3007 write_stringstream(fExternalFunctionsBuffer, out); 3008 write_stringstream(body, out); 3009 } 3010 3011 bool SPIRVCodeGenerator::generateCode() { 3012 ASSERT(!fErrors.errorCount()); 3013 this->writeWord(SpvMagicNumber, *fOut); 3014 this->writeWord(SpvVersion, *fOut); 3015 this->writeWord(SKSL_MAGIC, *fOut); 3016 StringStream buffer; 3017 this->writeInstructions(fProgram, buffer); 3018 this->writeWord(fIdCount, *fOut); 3019 this->writeWord(0, *fOut); // reserved, always zero 3020 write_stringstream(buffer, *fOut); 3021 return 0 == fErrors.errorCount(); 3022 } 3023 3024 } 3025