Home | History | Annotate | Download | only in sksl
      1 /*
      2  * Copyright 2016 Google Inc.
      3  *
      4  * Use of this source code is governed by a BSD-style license that can be
      5  * found in the LICENSE file.
      6  */
      7 
      8 #include "SkSLSPIRVCodeGenerator.h"
      9 
     10 #include "string.h"
     11 
     12 #include "GLSL.std.450.h"
     13 
     14 #include "ir/SkSLExpressionStatement.h"
     15 #include "ir/SkSLExtension.h"
     16 #include "ir/SkSLIndexExpression.h"
     17 #include "ir/SkSLVariableReference.h"
     18 #include "SkSLCompiler.h"
     19 
     20 namespace SkSL {
     21 
     22 #define SPIRV_DEBUG 0
     23 
     24 static const int32_t SKSL_MAGIC  = 0x0; // FIXME: we should probably register a magic number
     25 
     26 void SPIRVCodeGenerator::setupIntrinsics() {
     27 #define ALL_GLSL(x) std::make_tuple(kGLSL_STD_450_IntrinsicKind, GLSLstd450 ## x, GLSLstd450 ## x, \
     28                                     GLSLstd450 ## x, GLSLstd450 ## x)
     29 #define BY_TYPE_GLSL(ifFloat, ifInt, ifUInt) std::make_tuple(kGLSL_STD_450_IntrinsicKind, \
     30                                                              GLSLstd450 ## ifFloat, \
     31                                                              GLSLstd450 ## ifInt, \
     32                                                              GLSLstd450 ## ifUInt, \
     33                                                              SpvOpUndef)
     34 #define SPECIAL(x) std::make_tuple(kSpecial_IntrinsicKind, k ## x ## _SpecialIntrinsic, \
     35                                    k ## x ## _SpecialIntrinsic, k ## x ## _SpecialIntrinsic, \
     36                                    k ## x ## _SpecialIntrinsic)
     37     fIntrinsicMap[SkString("round")]         = ALL_GLSL(Round);
     38     fIntrinsicMap[SkString("roundEven")]     = ALL_GLSL(RoundEven);
     39     fIntrinsicMap[SkString("trunc")]         = ALL_GLSL(Trunc);
     40     fIntrinsicMap[SkString("abs")]           = BY_TYPE_GLSL(FAbs, SAbs, SAbs);
     41     fIntrinsicMap[SkString("sign")]          = BY_TYPE_GLSL(FSign, SSign, SSign);
     42     fIntrinsicMap[SkString("floor")]         = ALL_GLSL(Floor);
     43     fIntrinsicMap[SkString("ceil")]          = ALL_GLSL(Ceil);
     44     fIntrinsicMap[SkString("fract")]         = ALL_GLSL(Fract);
     45     fIntrinsicMap[SkString("radians")]       = ALL_GLSL(Radians);
     46     fIntrinsicMap[SkString("degrees")]       = ALL_GLSL(Degrees);
     47     fIntrinsicMap[SkString("sin")]           = ALL_GLSL(Sin);
     48     fIntrinsicMap[SkString("cos")]           = ALL_GLSL(Cos);
     49     fIntrinsicMap[SkString("tan")]           = ALL_GLSL(Tan);
     50     fIntrinsicMap[SkString("asin")]          = ALL_GLSL(Asin);
     51     fIntrinsicMap[SkString("acos")]          = ALL_GLSL(Acos);
     52     fIntrinsicMap[SkString("atan")]          = SPECIAL(Atan);
     53     fIntrinsicMap[SkString("sinh")]          = ALL_GLSL(Sinh);
     54     fIntrinsicMap[SkString("cosh")]          = ALL_GLSL(Cosh);
     55     fIntrinsicMap[SkString("tanh")]          = ALL_GLSL(Tanh);
     56     fIntrinsicMap[SkString("asinh")]         = ALL_GLSL(Asinh);
     57     fIntrinsicMap[SkString("acosh")]         = ALL_GLSL(Acosh);
     58     fIntrinsicMap[SkString("atanh")]         = ALL_GLSL(Atanh);
     59     fIntrinsicMap[SkString("pow")]           = ALL_GLSL(Pow);
     60     fIntrinsicMap[SkString("exp")]           = ALL_GLSL(Exp);
     61     fIntrinsicMap[SkString("log")]           = ALL_GLSL(Log);
     62     fIntrinsicMap[SkString("exp2")]          = ALL_GLSL(Exp2);
     63     fIntrinsicMap[SkString("log2")]          = ALL_GLSL(Log2);
     64     fIntrinsicMap[SkString("sqrt")]          = ALL_GLSL(Sqrt);
     65     fIntrinsicMap[SkString("inversesqrt")]   = ALL_GLSL(InverseSqrt);
     66     fIntrinsicMap[SkString("determinant")]   = ALL_GLSL(Determinant);
     67     fIntrinsicMap[SkString("matrixInverse")] = ALL_GLSL(MatrixInverse);
     68     fIntrinsicMap[SkString("mod")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpFMod,
     69                                                                SpvOpSMod, SpvOpUMod, SpvOpUndef);
     70     fIntrinsicMap[SkString("min")]           = BY_TYPE_GLSL(FMin, SMin, UMin);
     71     fIntrinsicMap[SkString("max")]           = BY_TYPE_GLSL(FMax, SMax, UMax);
     72     fIntrinsicMap[SkString("clamp")]         = BY_TYPE_GLSL(FClamp, SClamp, UClamp);
     73     fIntrinsicMap[SkString("dot")]           = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDot,
     74                                                                SpvOpUndef, SpvOpUndef, SpvOpUndef);
     75     fIntrinsicMap[SkString("mix")]           = ALL_GLSL(FMix);
     76     fIntrinsicMap[SkString("step")]          = ALL_GLSL(Step);
     77     fIntrinsicMap[SkString("smoothstep")]    = ALL_GLSL(SmoothStep);
     78     fIntrinsicMap[SkString("fma")]           = ALL_GLSL(Fma);
     79     fIntrinsicMap[SkString("frexp")]         = ALL_GLSL(Frexp);
     80     fIntrinsicMap[SkString("ldexp")]         = ALL_GLSL(Ldexp);
     81 
     82 #define PACK(type) fIntrinsicMap[SkString("pack" #type)] = ALL_GLSL(Pack ## type); \
     83                    fIntrinsicMap[SkString("unpack" #type)] = ALL_GLSL(Unpack ## type)
     84     PACK(Snorm4x8);
     85     PACK(Unorm4x8);
     86     PACK(Snorm2x16);
     87     PACK(Unorm2x16);
     88     PACK(Half2x16);
     89     PACK(Double2x32);
     90     fIntrinsicMap[SkString("length")]      = ALL_GLSL(Length);
     91     fIntrinsicMap[SkString("distance")]    = ALL_GLSL(Distance);
     92     fIntrinsicMap[SkString("cross")]       = ALL_GLSL(Cross);
     93     fIntrinsicMap[SkString("normalize")]   = ALL_GLSL(Normalize);
     94     fIntrinsicMap[SkString("faceForward")] = ALL_GLSL(FaceForward);
     95     fIntrinsicMap[SkString("reflect")]     = ALL_GLSL(Reflect);
     96     fIntrinsicMap[SkString("refract")]     = ALL_GLSL(Refract);
     97     fIntrinsicMap[SkString("findLSB")]     = ALL_GLSL(FindILsb);
     98     fIntrinsicMap[SkString("findMSB")]     = BY_TYPE_GLSL(FindSMsb, FindSMsb, FindUMsb);
     99     fIntrinsicMap[SkString("dFdx")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdx,
    100                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
    101     fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
    102                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
    103     fIntrinsicMap[SkString("dFdy")]        = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpDPdy,
    104                                                              SpvOpUndef, SpvOpUndef, SpvOpUndef);
    105     fIntrinsicMap[SkString("texture")]     = SPECIAL(Texture);
    106 
    107     fIntrinsicMap[SkString("subpassLoad")] = SPECIAL(SubpassLoad);
    108 
    109     fIntrinsicMap[SkString("any")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
    110                                                                   SpvOpUndef, SpvOpUndef, SpvOpAny);
    111     fIntrinsicMap[SkString("all")]              = std::make_tuple(kSPIRV_IntrinsicKind, SpvOpUndef,
    112                                                                   SpvOpUndef, SpvOpUndef, SpvOpAll);
    113     fIntrinsicMap[SkString("equal")]            = std::make_tuple(kSPIRV_IntrinsicKind,
    114                                                                   SpvOpFOrdEqual, SpvOpIEqual,
    115                                                                   SpvOpIEqual, SpvOpLogicalEqual);
    116     fIntrinsicMap[SkString("notEqual")]         = std::make_tuple(kSPIRV_IntrinsicKind,
    117                                                                   SpvOpFOrdNotEqual, SpvOpINotEqual,
    118                                                                   SpvOpINotEqual,
    119                                                                   SpvOpLogicalNotEqual);
    120     fIntrinsicMap[SkString("lessThan")]         = std::make_tuple(kSPIRV_IntrinsicKind,
    121                                                                   SpvOpSLessThan, SpvOpULessThan,
    122                                                                   SpvOpFOrdLessThan, SpvOpUndef);
    123     fIntrinsicMap[SkString("lessThanEqual")]    = std::make_tuple(kSPIRV_IntrinsicKind,
    124                                                                   SpvOpSLessThanEqual,
    125                                                                   SpvOpULessThanEqual,
    126                                                                   SpvOpFOrdLessThanEqual,
    127                                                                   SpvOpUndef);
    128     fIntrinsicMap[SkString("greaterThan")]      = std::make_tuple(kSPIRV_IntrinsicKind,
    129                                                                   SpvOpSGreaterThan,
    130                                                                   SpvOpUGreaterThan,
    131                                                                   SpvOpFOrdGreaterThan,
    132                                                                   SpvOpUndef);
    133     fIntrinsicMap[SkString("greaterThanEqual")] = std::make_tuple(kSPIRV_IntrinsicKind,
    134                                                                   SpvOpSGreaterThanEqual,
    135                                                                   SpvOpUGreaterThanEqual,
    136                                                                   SpvOpFOrdGreaterThanEqual,
    137                                                                   SpvOpUndef);
    138 
    139 // interpolateAt* not yet supported...
    140 }
    141 
    142 void SPIRVCodeGenerator::writeWord(int32_t word, SkWStream& out) {
    143 #if SPIRV_DEBUG
    144     out << "(" << word << ") ";
    145 #else
    146     out.write((const char*) &word, sizeof(word));
    147 #endif
    148 }
    149 
    150 static bool is_float(const Context& context, const Type& type) {
    151     if (type.kind() == Type::kVector_Kind) {
    152         return is_float(context, type.componentType());
    153     }
    154     return type == *context.fFloat_Type || type == *context.fDouble_Type;
    155 }
    156 
    157 static bool is_signed(const Context& context, const Type& type) {
    158     if (type.kind() == Type::kVector_Kind) {
    159         return is_signed(context, type.componentType());
    160     }
    161     return type == *context.fInt_Type;
    162 }
    163 
    164 static bool is_unsigned(const Context& context, const Type& type) {
    165     if (type.kind() == Type::kVector_Kind) {
    166         return is_unsigned(context, type.componentType());
    167     }
    168     return type == *context.fUInt_Type;
    169 }
    170 
    171 static bool is_bool(const Context& context, const Type& type) {
    172     if (type.kind() == Type::kVector_Kind) {
    173         return is_bool(context, type.componentType());
    174     }
    175     return type == *context.fBool_Type;
    176 }
    177 
    178 static bool is_out(const Variable& var) {
    179     return (var.fModifiers.fFlags & Modifiers::kOut_Flag) != 0;
    180 }
    181 
    182 #if SPIRV_DEBUG
    183 static SkString opcode_text(SpvOp_ opCode) {
    184     switch (opCode) {
    185         case SpvOpNop:
    186             return SkString("Nop");
    187         case SpvOpUndef:
    188             return SkString("Undef");
    189         case SpvOpSourceContinued:
    190             return SkString("SourceContinued");
    191         case SpvOpSource:
    192             return SkString("Source");
    193         case SpvOpSourceExtension:
    194             return SkString("SourceExtension");
    195         case SpvOpName:
    196             return SkString("Name");
    197         case SpvOpMemberName:
    198             return SkString("MemberName");
    199         case SpvOpString:
    200             return SkString("String");
    201         case SpvOpLine:
    202             return SkString("Line");
    203         case SpvOpExtension:
    204             return SkString("Extension");
    205         case SpvOpExtInstImport:
    206             return SkString("ExtInstImport");
    207         case SpvOpExtInst:
    208             return SkString("ExtInst");
    209         case SpvOpMemoryModel:
    210             return SkString("MemoryModel");
    211         case SpvOpEntryPoint:
    212             return SkString("EntryPoint");
    213         case SpvOpExecutionMode:
    214             return SkString("ExecutionMode");
    215         case SpvOpCapability:
    216             return SkString("Capability");
    217         case SpvOpTypeVoid:
    218             return SkString("TypeVoid");
    219         case SpvOpTypeBool:
    220             return SkString("TypeBool");
    221         case SpvOpTypeInt:
    222             return SkString("TypeInt");
    223         case SpvOpTypeFloat:
    224             return SkString("TypeFloat");
    225         case SpvOpTypeVector:
    226             return SkString("TypeVector");
    227         case SpvOpTypeMatrix:
    228             return SkString("TypeMatrix");
    229         case SpvOpTypeImage:
    230             return SkString("TypeImage");
    231         case SpvOpTypeSampler:
    232             return SkString("TypeSampler");
    233         case SpvOpTypeSampledImage:
    234             return SkString("TypeSampledImage");
    235         case SpvOpTypeArray:
    236             return SkString("TypeArray");
    237         case SpvOpTypeRuntimeArray:
    238             return SkString("TypeRuntimeArray");
    239         case SpvOpTypeStruct:
    240             return SkString("TypeStruct");
    241         case SpvOpTypeOpaque:
    242             return SkString("TypeOpaque");
    243         case SpvOpTypePointer:
    244             return SkString("TypePointer");
    245         case SpvOpTypeFunction:
    246             return SkString("TypeFunction");
    247         case SpvOpTypeEvent:
    248             return SkString("TypeEvent");
    249         case SpvOpTypeDeviceEvent:
    250             return SkString("TypeDeviceEvent");
    251         case SpvOpTypeReserveId:
    252             return SkString("TypeReserveId");
    253         case SpvOpTypeQueue:
    254             return SkString("TypeQueue");
    255         case SpvOpTypePipe:
    256             return SkString("TypePipe");
    257         case SpvOpTypeForwardPointer:
    258             return SkString("TypeForwardPointer");
    259         case SpvOpConstantTrue:
    260             return SkString("ConstantTrue");
    261         case SpvOpConstantFalse:
    262             return SkString("ConstantFalse");
    263         case SpvOpConstant:
    264             return SkString("Constant");
    265         case SpvOpConstantComposite:
    266             return SkString("ConstantComposite");
    267         case SpvOpConstantSampler:
    268             return SkString("ConstantSampler");
    269         case SpvOpConstantNull:
    270             return SkString("ConstantNull");
    271         case SpvOpSpecConstantTrue:
    272             return SkString("SpecConstantTrue");
    273         case SpvOpSpecConstantFalse:
    274             return SkString("SpecConstantFalse");
    275         case SpvOpSpecConstant:
    276             return SkString("SpecConstant");
    277         case SpvOpSpecConstantComposite:
    278             return SkString("SpecConstantComposite");
    279         case SpvOpSpecConstantOp:
    280             return SkString("SpecConstantOp");
    281         case SpvOpFunction:
    282             return SkString("Function");
    283         case SpvOpFunctionParameter:
    284             return SkString("FunctionParameter");
    285         case SpvOpFunctionEnd:
    286             return SkString("FunctionEnd");
    287         case SpvOpFunctionCall:
    288             return SkString("FunctionCall");
    289         case SpvOpVariable:
    290             return SkString("Variable");
    291         case SpvOpImageTexelPointer:
    292             return SkString("ImageTexelPointer");
    293         case SpvOpLoad:
    294             return SkString("Load");
    295         case SpvOpStore:
    296             return SkString("Store");
    297         case SpvOpCopyMemory:
    298             return SkString("CopyMemory");
    299         case SpvOpCopyMemorySized:
    300             return SkString("CopyMemorySized");
    301         case SpvOpAccessChain:
    302             return SkString("AccessChain");
    303         case SpvOpInBoundsAccessChain:
    304             return SkString("InBoundsAccessChain");
    305         case SpvOpPtrAccessChain:
    306             return SkString("PtrAccessChain");
    307         case SpvOpArrayLength:
    308             return SkString("ArrayLength");
    309         case SpvOpGenericPtrMemSemantics:
    310             return SkString("GenericPtrMemSemantics");
    311         case SpvOpInBoundsPtrAccessChain:
    312             return SkString("InBoundsPtrAccessChain");
    313         case SpvOpDecorate:
    314             return SkString("Decorate");
    315         case SpvOpMemberDecorate:
    316             return SkString("MemberDecorate");
    317         case SpvOpDecorationGroup:
    318             return SkString("DecorationGroup");
    319         case SpvOpGroupDecorate:
    320             return SkString("GroupDecorate");
    321         case SpvOpGroupMemberDecorate:
    322             return SkString("GroupMemberDecorate");
    323         case SpvOpVectorExtractDynamic:
    324             return SkString("VectorExtractDynamic");
    325         case SpvOpVectorInsertDynamic:
    326             return SkString("VectorInsertDynamic");
    327         case SpvOpVectorShuffle:
    328             return SkString("VectorShuffle");
    329         case SpvOpCompositeConstruct:
    330             return SkString("CompositeConstruct");
    331         case SpvOpCompositeExtract:
    332             return SkString("CompositeExtract");
    333         case SpvOpCompositeInsert:
    334             return SkString("CompositeInsert");
    335         case SpvOpCopyObject:
    336             return SkString("CopyObject");
    337         case SpvOpTranspose:
    338             return SkString("Transpose");
    339         case SpvOpSampledImage:
    340             return SkString("SampledImage");
    341         case SpvOpImageSampleImplicitLod:
    342             return SkString("ImageSampleImplicitLod");
    343         case SpvOpImageSampleExplicitLod:
    344             return SkString("ImageSampleExplicitLod");
    345         case SpvOpImageSampleDrefImplicitLod:
    346             return SkString("ImageSampleDrefImplicitLod");
    347         case SpvOpImageSampleDrefExplicitLod:
    348             return SkString("ImageSampleDrefExplicitLod");
    349         case SpvOpImageSampleProjImplicitLod:
    350             return SkString("ImageSampleProjImplicitLod");
    351         case SpvOpImageSampleProjExplicitLod:
    352             return SkString("ImageSampleProjExplicitLod");
    353         case SpvOpImageSampleProjDrefImplicitLod:
    354             return SkString("ImageSampleProjDrefImplicitLod");
    355         case SpvOpImageSampleProjDrefExplicitLod:
    356             return SkString("ImageSampleProjDrefExplicitLod");
    357         case SpvOpImageFetch:
    358             return SkString("ImageFetch");
    359         case SpvOpImageGather:
    360             return SkString("ImageGather");
    361         case SpvOpImageDrefGather:
    362             return SkString("ImageDrefGather");
    363         case SpvOpImageRead:
    364             return SkString("ImageRead");
    365         case SpvOpImageWrite:
    366             return SkString("ImageWrite");
    367         case SpvOpImage:
    368             return SkString("Image");
    369         case SpvOpImageQueryFormat:
    370             return SkString("ImageQueryFormat");
    371         case SpvOpImageQueryOrder:
    372             return SkString("ImageQueryOrder");
    373         case SpvOpImageQuerySizeLod:
    374             return SkString("ImageQuerySizeLod");
    375         case SpvOpImageQuerySize:
    376             return SkString("ImageQuerySize");
    377         case SpvOpImageQueryLod:
    378             return SkString("ImageQueryLod");
    379         case SpvOpImageQueryLevels:
    380             return SkString("ImageQueryLevels");
    381         case SpvOpImageQuerySamples:
    382             return SkString("ImageQuerySamples");
    383         case SpvOpConvertFToU:
    384             return SkString("ConvertFToU");
    385         case SpvOpConvertFToS:
    386             return SkString("ConvertFToS");
    387         case SpvOpConvertSToF:
    388             return SkString("ConvertSToF");
    389         case SpvOpConvertUToF:
    390             return SkString("ConvertUToF");
    391         case SpvOpUConvert:
    392             return SkString("UConvert");
    393         case SpvOpSConvert:
    394             return SkString("SConvert");
    395         case SpvOpFConvert:
    396             return SkString("FConvert");
    397         case SpvOpQuantizeToF16:
    398             return SkString("QuantizeToF16");
    399         case SpvOpConvertPtrToU:
    400             return SkString("ConvertPtrToU");
    401         case SpvOpSatConvertSToU:
    402             return SkString("SatConvertSToU");
    403         case SpvOpSatConvertUToS:
    404             return SkString("SatConvertUToS");
    405         case SpvOpConvertUToPtr:
    406             return SkString("ConvertUToPtr");
    407         case SpvOpPtrCastToGeneric:
    408             return SkString("PtrCastToGeneric");
    409         case SpvOpGenericCastToPtr:
    410             return SkString("GenericCastToPtr");
    411         case SpvOpGenericCastToPtrExplicit:
    412             return SkString("GenericCastToPtrExplicit");
    413         case SpvOpBitcast:
    414             return SkString("Bitcast");
    415         case SpvOpSNegate:
    416             return SkString("SNegate");
    417         case SpvOpFNegate:
    418             return SkString("FNegate");
    419         case SpvOpIAdd:
    420             return SkString("IAdd");
    421         case SpvOpFAdd:
    422             return SkString("FAdd");
    423         case SpvOpISub:
    424             return SkString("ISub");
    425         case SpvOpFSub:
    426             return SkString("FSub");
    427         case SpvOpIMul:
    428             return SkString("IMul");
    429         case SpvOpFMul:
    430             return SkString("FMul");
    431         case SpvOpUDiv:
    432             return SkString("UDiv");
    433         case SpvOpSDiv:
    434             return SkString("SDiv");
    435         case SpvOpFDiv:
    436             return SkString("FDiv");
    437         case SpvOpUMod:
    438             return SkString("UMod");
    439         case SpvOpSRem:
    440             return SkString("SRem");
    441         case SpvOpSMod:
    442             return SkString("SMod");
    443         case SpvOpFRem:
    444             return SkString("FRem");
    445         case SpvOpFMod:
    446             return SkString("FMod");
    447         case SpvOpVectorTimesScalar:
    448             return SkString("VectorTimesScalar");
    449         case SpvOpMatrixTimesScalar:
    450             return SkString("MatrixTimesScalar");
    451         case SpvOpVectorTimesMatrix:
    452             return SkString("VectorTimesMatrix");
    453         case SpvOpMatrixTimesVector:
    454             return SkString("MatrixTimesVector");
    455         case SpvOpMatrixTimesMatrix:
    456             return SkString("MatrixTimesMatrix");
    457         case SpvOpOuterProduct:
    458             return SkString("OuterProduct");
    459         case SpvOpDot:
    460             return SkString("Dot");
    461         case SpvOpIAddCarry:
    462             return SkString("IAddCarry");
    463         case SpvOpISubBorrow:
    464             return SkString("ISubBorrow");
    465         case SpvOpUMulExtended:
    466             return SkString("UMulExtended");
    467         case SpvOpSMulExtended:
    468             return SkString("SMulExtended");
    469         case SpvOpAny:
    470             return SkString("Any");
    471         case SpvOpAll:
    472             return SkString("All");
    473         case SpvOpIsNan:
    474             return SkString("IsNan");
    475         case SpvOpIsInf:
    476             return SkString("IsInf");
    477         case SpvOpIsFinite:
    478             return SkString("IsFinite");
    479         case SpvOpIsNormal:
    480             return SkString("IsNormal");
    481         case SpvOpSignBitSet:
    482             return SkString("SignBitSet");
    483         case SpvOpLessOrGreater:
    484             return SkString("LessOrGreater");
    485         case SpvOpOrdered:
    486             return SkString("Ordered");
    487         case SpvOpUnordered:
    488             return SkString("Unordered");
    489         case SpvOpLogicalEqual:
    490             return SkString("LogicalEqual");
    491         case SpvOpLogicalNotEqual:
    492             return SkString("LogicalNotEqual");
    493         case SpvOpLogicalOr:
    494             return SkString("LogicalOr");
    495         case SpvOpLogicalAnd:
    496             return SkString("LogicalAnd");
    497         case SpvOpLogicalNot:
    498             return SkString("LogicalNot");
    499         case SpvOpSelect:
    500             return SkString("Select");
    501         case SpvOpIEqual:
    502             return SkString("IEqual");
    503         case SpvOpINotEqual:
    504             return SkString("INotEqual");
    505         case SpvOpUGreaterThan:
    506             return SkString("UGreaterThan");
    507         case SpvOpSGreaterThan:
    508             return SkString("SGreaterThan");
    509         case SpvOpUGreaterThanEqual:
    510             return SkString("UGreaterThanEqual");
    511         case SpvOpSGreaterThanEqual:
    512             return SkString("SGreaterThanEqual");
    513         case SpvOpULessThan:
    514             return SkString("ULessThan");
    515         case SpvOpSLessThan:
    516             return SkString("SLessThan");
    517         case SpvOpULessThanEqual:
    518             return SkString("ULessThanEqual");
    519         case SpvOpSLessThanEqual:
    520             return SkString("SLessThanEqual");
    521         case SpvOpFOrdEqual:
    522             return SkString("FOrdEqual");
    523         case SpvOpFUnordEqual:
    524             return SkString("FUnordEqual");
    525         case SpvOpFOrdNotEqual:
    526             return SkString("FOrdNotEqual");
    527         case SpvOpFUnordNotEqual:
    528             return SkString("FUnordNotEqual");
    529         case SpvOpFOrdLessThan:
    530             return SkString("FOrdLessThan");
    531         case SpvOpFUnordLessThan:
    532             return SkString("FUnordLessThan");
    533         case SpvOpFOrdGreaterThan:
    534             return SkString("FOrdGreaterThan");
    535         case SpvOpFUnordGreaterThan:
    536             return SkString("FUnordGreaterThan");
    537         case SpvOpFOrdLessThanEqual:
    538             return SkString("FOrdLessThanEqual");
    539         case SpvOpFUnordLessThanEqual:
    540             return SkString("FUnordLessThanEqual");
    541         case SpvOpFOrdGreaterThanEqual:
    542             return SkString("FOrdGreaterThanEqual");
    543         case SpvOpFUnordGreaterThanEqual:
    544             return SkString("FUnordGreaterThanEqual");
    545         case SpvOpShiftRightLogical:
    546             return SkString("ShiftRightLogical");
    547         case SpvOpShiftRightArithmetic:
    548             return SkString("ShiftRightArithmetic");
    549         case SpvOpShiftLeftLogical:
    550             return SkString("ShiftLeftLogical");
    551         case SpvOpBitwiseOr:
    552             return SkString("BitwiseOr");
    553         case SpvOpBitwiseXor:
    554             return SkString("BitwiseXor");
    555         case SpvOpBitwiseAnd:
    556             return SkString("BitwiseAnd");
    557         case SpvOpNot:
    558             return SkString("Not");
    559         case SpvOpBitFieldInsert:
    560             return SkString("BitFieldInsert");
    561         case SpvOpBitFieldSExtract:
    562             return SkString("BitFieldSExtract");
    563         case SpvOpBitFieldUExtract:
    564             return SkString("BitFieldUExtract");
    565         case SpvOpBitReverse:
    566             return SkString("BitReverse");
    567         case SpvOpBitCount:
    568             return SkString("BitCount");
    569         case SpvOpDPdx:
    570             return SkString("DPdx");
    571         case SpvOpDPdy:
    572             return SkString("DPdy");
    573         case SpvOpFwidth:
    574             return SkString("Fwidth");
    575         case SpvOpDPdxFine:
    576             return SkString("DPdxFine");
    577         case SpvOpDPdyFine:
    578             return SkString("DPdyFine");
    579         case SpvOpFwidthFine:
    580             return SkString("FwidthFine");
    581         case SpvOpDPdxCoarse:
    582             return SkString("DPdxCoarse");
    583         case SpvOpDPdyCoarse:
    584             return SkString("DPdyCoarse");
    585         case SpvOpFwidthCoarse:
    586             return SkString("FwidthCoarse");
    587         case SpvOpEmitVertex:
    588             return SkString("EmitVertex");
    589         case SpvOpEndPrimitive:
    590             return SkString("EndPrimitive");
    591         case SpvOpEmitStreamVertex:
    592             return SkString("EmitStreamVertex");
    593         case SpvOpEndStreamPrimitive:
    594             return SkString("EndStreamPrimitive");
    595         case SpvOpControlBarrier:
    596             return SkString("ControlBarrier");
    597         case SpvOpMemoryBarrier:
    598             return SkString("MemoryBarrier");
    599         case SpvOpAtomicLoad:
    600             return SkString("AtomicLoad");
    601         case SpvOpAtomicStore:
    602             return SkString("AtomicStore");
    603         case SpvOpAtomicExchange:
    604             return SkString("AtomicExchange");
    605         case SpvOpAtomicCompareExchange:
    606             return SkString("AtomicCompareExchange");
    607         case SpvOpAtomicCompareExchangeWeak:
    608             return SkString("AtomicCompareExchangeWeak");
    609         case SpvOpAtomicIIncrement:
    610             return SkString("AtomicIIncrement");
    611         case SpvOpAtomicIDecrement:
    612             return SkString("AtomicIDecrement");
    613         case SpvOpAtomicIAdd:
    614             return SkString("AtomicIAdd");
    615         case SpvOpAtomicISub:
    616             return SkString("AtomicISub");
    617         case SpvOpAtomicSMin:
    618             return SkString("AtomicSMin");
    619         case SpvOpAtomicUMin:
    620             return SkString("AtomicUMin");
    621         case SpvOpAtomicSMax:
    622             return SkString("AtomicSMax");
    623         case SpvOpAtomicUMax:
    624             return SkString("AtomicUMax");
    625         case SpvOpAtomicAnd:
    626             return SkString("AtomicAnd");
    627         case SpvOpAtomicOr:
    628             return SkString("AtomicOr");
    629         case SpvOpAtomicXor:
    630             return SkString("AtomicXor");
    631         case SpvOpPhi:
    632             return SkString("Phi");
    633         case SpvOpLoopMerge:
    634             return SkString("LoopMerge");
    635         case SpvOpSelectionMerge:
    636             return SkString("SelectionMerge");
    637         case SpvOpLabel:
    638             return SkString("Label");
    639         case SpvOpBranch:
    640             return SkString("Branch");
    641         case SpvOpBranchConditional:
    642             return SkString("BranchConditional");
    643         case SpvOpSwitch:
    644             return SkString("Switch");
    645         case SpvOpKill:
    646             return SkString("Kill");
    647         case SpvOpReturn:
    648             return SkString("Return");
    649         case SpvOpReturnValue:
    650             return SkString("ReturnValue");
    651         case SpvOpUnreachable:
    652             return SkString("Unreachable");
    653         case SpvOpLifetimeStart:
    654             return SkString("LifetimeStart");
    655         case SpvOpLifetimeStop:
    656             return SkString("LifetimeStop");
    657         case SpvOpGroupAsyncCopy:
    658             return SkString("GroupAsyncCopy");
    659         case SpvOpGroupWaitEvents:
    660             return SkString("GroupWaitEvents");
    661         case SpvOpGroupAll:
    662             return SkString("GroupAll");
    663         case SpvOpGroupAny:
    664             return SkString("GroupAny");
    665         case SpvOpGroupBroadcast:
    666             return SkString("GroupBroadcast");
    667         case SpvOpGroupIAdd:
    668             return SkString("GroupIAdd");
    669         case SpvOpGroupFAdd:
    670             return SkString("GroupFAdd");
    671         case SpvOpGroupFMin:
    672             return SkString("GroupFMin");
    673         case SpvOpGroupUMin:
    674             return SkString("GroupUMin");
    675         case SpvOpGroupSMin:
    676             return SkString("GroupSMin");
    677         case SpvOpGroupFMax:
    678             return SkString("GroupFMax");
    679         case SpvOpGroupUMax:
    680             return SkString("GroupUMax");
    681         case SpvOpGroupSMax:
    682             return SkString("GroupSMax");
    683         case SpvOpReadPipe:
    684             return SkString("ReadPipe");
    685         case SpvOpWritePipe:
    686             return SkString("WritePipe");
    687         case SpvOpReservedReadPipe:
    688             return SkString("ReservedReadPipe");
    689         case SpvOpReservedWritePipe:
    690             return SkString("ReservedWritePipe");
    691         case SpvOpReserveReadPipePackets:
    692             return SkString("ReserveReadPipePackets");
    693         case SpvOpReserveWritePipePackets:
    694             return SkString("ReserveWritePipePackets");
    695         case SpvOpCommitReadPipe:
    696             return SkString("CommitReadPipe");
    697         case SpvOpCommitWritePipe:
    698             return SkString("CommitWritePipe");
    699         case SpvOpIsValidReserveId:
    700             return SkString("IsValidReserveId");
    701         case SpvOpGetNumPipePackets:
    702             return SkString("GetNumPipePackets");
    703         case SpvOpGetMaxPipePackets:
    704             return SkString("GetMaxPipePackets");
    705         case SpvOpGroupReserveReadPipePackets:
    706             return SkString("GroupReserveReadPipePackets");
    707         case SpvOpGroupReserveWritePipePackets:
    708             return SkString("GroupReserveWritePipePackets");
    709         case SpvOpGroupCommitReadPipe:
    710             return SkString("GroupCommitReadPipe");
    711         case SpvOpGroupCommitWritePipe:
    712             return SkString("GroupCommitWritePipe");
    713         case SpvOpEnqueueMarker:
    714             return SkString("EnqueueMarker");
    715         case SpvOpEnqueueKernel:
    716             return SkString("EnqueueKernel");
    717         case SpvOpGetKernelNDrangeSubGroupCount:
    718             return SkString("GetKernelNDrangeSubGroupCount");
    719         case SpvOpGetKernelNDrangeMaxSubGroupSize:
    720             return SkString("GetKernelNDrangeMaxSubGroupSize");
    721         case SpvOpGetKernelWorkGroupSize:
    722             return SkString("GetKernelWorkGroupSize");
    723         case SpvOpGetKernelPreferredWorkGroupSizeMultiple:
    724             return SkString("GetKernelPreferredWorkGroupSizeMultiple");
    725         case SpvOpRetainEvent:
    726             return SkString("RetainEvent");
    727         case SpvOpReleaseEvent:
    728             return SkString("ReleaseEvent");
    729         case SpvOpCreateUserEvent:
    730             return SkString("CreateUserEvent");
    731         case SpvOpIsValidEvent:
    732             return SkString("IsValidEvent");
    733         case SpvOpSetUserEventStatus:
    734             return SkString("SetUserEventStatus");
    735         case SpvOpCaptureEventProfilingInfo:
    736             return SkString("CaptureEventProfilingInfo");
    737         case SpvOpGetDefaultQueue:
    738             return SkString("GetDefaultQueue");
    739         case SpvOpBuildNDRange:
    740             return SkString("BuildNDRange");
    741         case SpvOpImageSparseSampleImplicitLod:
    742             return SkString("ImageSparseSampleImplicitLod");
    743         case SpvOpImageSparseSampleExplicitLod:
    744             return SkString("ImageSparseSampleExplicitLod");
    745         case SpvOpImageSparseSampleDrefImplicitLod:
    746             return SkString("ImageSparseSampleDrefImplicitLod");
    747         case SpvOpImageSparseSampleDrefExplicitLod:
    748             return SkString("ImageSparseSampleDrefExplicitLod");
    749         case SpvOpImageSparseSampleProjImplicitLod:
    750             return SkString("ImageSparseSampleProjImplicitLod");
    751         case SpvOpImageSparseSampleProjExplicitLod:
    752             return SkString("ImageSparseSampleProjExplicitLod");
    753         case SpvOpImageSparseSampleProjDrefImplicitLod:
    754             return SkString("ImageSparseSampleProjDrefImplicitLod");
    755         case SpvOpImageSparseSampleProjDrefExplicitLod:
    756             return SkString("ImageSparseSampleProjDrefExplicitLod");
    757         case SpvOpImageSparseFetch:
    758             return SkString("ImageSparseFetch");
    759         case SpvOpImageSparseGather:
    760             return SkString("ImageSparseGather");
    761         case SpvOpImageSparseDrefGather:
    762             return SkString("ImageSparseDrefGather");
    763         case SpvOpImageSparseTexelsResident:
    764             return SkString("ImageSparseTexelsResident");
    765         case SpvOpNoLine:
    766             return SkString("NoLine");
    767         case SpvOpAtomicFlagTestAndSet:
    768             return SkString("AtomicFlagTestAndSet");
    769         case SpvOpAtomicFlagClear:
    770             return SkString("AtomicFlagClear");
    771         case SpvOpImageSparseRead:
    772             return SkString("ImageSparseRead");
    773         default:
    774             ABORT("unsupported SPIR-V op");
    775     }
    776 }
    777 #endif
    778 
    779 void SPIRVCodeGenerator::writeOpCode(SpvOp_ opCode, int length, SkWStream& out) {
    780     ASSERT(opCode != SpvOpUndef);
    781     switch (opCode) {
    782         case SpvOpReturn:      // fall through
    783         case SpvOpReturnValue: // fall through
    784         case SpvOpKill:        // fall through
    785         case SpvOpBranch:      // fall through
    786         case SpvOpBranchConditional:
    787             ASSERT(fCurrentBlock);
    788             fCurrentBlock = 0;
    789             break;
    790         case SpvOpConstant:          // fall through
    791         case SpvOpConstantTrue:      // fall through
    792         case SpvOpConstantFalse:     // fall through
    793         case SpvOpConstantComposite: // fall through
    794         case SpvOpTypeVoid:          // fall through
    795         case SpvOpTypeInt:           // fall through
    796         case SpvOpTypeFloat:         // fall through
    797         case SpvOpTypeBool:          // fall through
    798         case SpvOpTypeVector:        // fall through
    799         case SpvOpTypeMatrix:        // fall through
    800         case SpvOpTypeArray:         // fall through
    801         case SpvOpTypePointer:       // fall through
    802         case SpvOpTypeFunction:      // fall through
    803         case SpvOpTypeRuntimeArray:  // fall through
    804         case SpvOpTypeStruct:        // fall through
    805         case SpvOpTypeImage:         // fall through
    806         case SpvOpTypeSampledImage:  // fall through
    807         case SpvOpVariable:          // fall through
    808         case SpvOpFunction:          // fall through
    809         case SpvOpFunctionParameter: // fall through
    810         case SpvOpFunctionEnd:       // fall through
    811         case SpvOpExecutionMode:     // fall through
    812         case SpvOpMemoryModel:       // fall through
    813         case SpvOpCapability:        // fall through
    814         case SpvOpExtInstImport:     // fall through
    815         case SpvOpEntryPoint:        // fall through
    816         case SpvOpSource:            // fall through
    817         case SpvOpSourceExtension:   // fall through
    818         case SpvOpName:              // fall through
    819         case SpvOpMemberName:        // fall through
    820         case SpvOpDecorate:          // fall through
    821         case SpvOpMemberDecorate:
    822             break;
    823         default:
    824             ASSERT(fCurrentBlock);
    825     }
    826 #if SPIRV_DEBUG
    827     out << std::endl << opcode_text(opCode) << " ";
    828 #else
    829     this->writeWord((length << 16) | opCode, out);
    830 #endif
    831 }
    832 
    833 void SPIRVCodeGenerator::writeLabel(SpvId label, SkWStream& out) {
    834     fCurrentBlock = label;
    835     this->writeInstruction(SpvOpLabel, label, out);
    836 }
    837 
    838 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, SkWStream& out) {
    839     this->writeOpCode(opCode, 1, out);
    840 }
    841 
    842 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, SkWStream& out) {
    843     this->writeOpCode(opCode, 2, out);
    844     this->writeWord(word1, out);
    845 }
    846 
    847 void SPIRVCodeGenerator::writeString(const char* string, SkWStream& out) {
    848     size_t length = strlen(string);
    849     out.writeText(string);
    850     switch (length % 4) {
    851         case 1:
    852             out.write8(0);
    853             // fall through
    854         case 2:
    855             out.write8(0);
    856             // fall through
    857         case 3:
    858             out.write8(0);
    859             break;
    860         default:
    861             this->writeWord(0, out);
    862     }
    863 }
    864 
    865 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, const char* string, SkWStream& out) {
    866     int32_t length = (int32_t) strlen(string);
    867     this->writeOpCode(opCode, 1 + (length + 4) / 4, out);
    868     this->writeString(string, out);
    869 }
    870 
    871 
    872 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, const char* string,
    873                                           SkWStream& out) {
    874     int32_t length = (int32_t) strlen(string);
    875     this->writeOpCode(opCode, 2 + (length + 4) / 4, out);
    876     this->writeWord(word1, out);
    877     this->writeString(string, out);
    878 }
    879 
    880 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    881                                           const char* string, SkWStream& out) {
    882     int32_t length = (int32_t) strlen(string);
    883     this->writeOpCode(opCode, 3 + (length + 4) / 4, out);
    884     this->writeWord(word1, out);
    885     this->writeWord(word2, out);
    886     this->writeString(string, out);
    887 }
    888 
    889 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    890                                           SkWStream& out) {
    891     this->writeOpCode(opCode, 3, out);
    892     this->writeWord(word1, out);
    893     this->writeWord(word2, out);
    894 }
    895 
    896 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    897                                           int32_t word3, SkWStream& out) {
    898     this->writeOpCode(opCode, 4, out);
    899     this->writeWord(word1, out);
    900     this->writeWord(word2, out);
    901     this->writeWord(word3, out);
    902 }
    903 
    904 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    905                                           int32_t word3, int32_t word4, SkWStream& out) {
    906     this->writeOpCode(opCode, 5, out);
    907     this->writeWord(word1, out);
    908     this->writeWord(word2, out);
    909     this->writeWord(word3, out);
    910     this->writeWord(word4, out);
    911 }
    912 
    913 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    914                                           int32_t word3, int32_t word4, int32_t word5,
    915                                           SkWStream& out) {
    916     this->writeOpCode(opCode, 6, out);
    917     this->writeWord(word1, out);
    918     this->writeWord(word2, out);
    919     this->writeWord(word3, out);
    920     this->writeWord(word4, out);
    921     this->writeWord(word5, out);
    922 }
    923 
    924 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    925                                           int32_t word3, int32_t word4, int32_t word5,
    926                                           int32_t word6, SkWStream& out) {
    927     this->writeOpCode(opCode, 7, out);
    928     this->writeWord(word1, out);
    929     this->writeWord(word2, out);
    930     this->writeWord(word3, out);
    931     this->writeWord(word4, out);
    932     this->writeWord(word5, out);
    933     this->writeWord(word6, out);
    934 }
    935 
    936 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    937                                           int32_t word3, int32_t word4, int32_t word5,
    938                                           int32_t word6, int32_t word7, SkWStream& out) {
    939     this->writeOpCode(opCode, 8, out);
    940     this->writeWord(word1, out);
    941     this->writeWord(word2, out);
    942     this->writeWord(word3, out);
    943     this->writeWord(word4, out);
    944     this->writeWord(word5, out);
    945     this->writeWord(word6, out);
    946     this->writeWord(word7, out);
    947 }
    948 
    949 void SPIRVCodeGenerator::writeInstruction(SpvOp_ opCode, int32_t word1, int32_t word2,
    950                                           int32_t word3, int32_t word4, int32_t word5,
    951                                           int32_t word6, int32_t word7, int32_t word8,
    952                                           SkWStream& out) {
    953     this->writeOpCode(opCode, 9, out);
    954     this->writeWord(word1, out);
    955     this->writeWord(word2, out);
    956     this->writeWord(word3, out);
    957     this->writeWord(word4, out);
    958     this->writeWord(word5, out);
    959     this->writeWord(word6, out);
    960     this->writeWord(word7, out);
    961     this->writeWord(word8, out);
    962 }
    963 
    964 void SPIRVCodeGenerator::writeCapabilities(SkWStream& out) {
    965     for (uint64_t i = 0, bit = 1; i <= kLast_Capability; i++, bit <<= 1) {
    966         if (fCapabilities & bit) {
    967             this->writeInstruction(SpvOpCapability, (SpvId) i, out);
    968         }
    969     }
    970 }
    971 
    972 SpvId SPIRVCodeGenerator::nextId() {
    973     return fIdCount++;
    974 }
    975 
    976 void SPIRVCodeGenerator::writeStruct(const Type& type, const MemoryLayout& memoryLayout,
    977                                      SpvId resultId) {
    978     this->writeInstruction(SpvOpName, resultId, type.name().c_str(), fNameBuffer);
    979     // go ahead and write all of the field types, so we don't inadvertently write them while we're
    980     // in the middle of writing the struct instruction
    981     std::vector<SpvId> types;
    982     for (const auto& f : type.fields()) {
    983         types.push_back(this->getType(*f.fType, memoryLayout));
    984     }
    985     this->writeOpCode(SpvOpTypeStruct, 2 + (int32_t) types.size(), fConstantBuffer);
    986     this->writeWord(resultId, fConstantBuffer);
    987     for (SpvId id : types) {
    988         this->writeWord(id, fConstantBuffer);
    989     }
    990     size_t offset = 0;
    991     for (int32_t i = 0; i < (int32_t) type.fields().size(); i++) {
    992         size_t size = memoryLayout.size(*type.fields()[i].fType);
    993         size_t alignment = memoryLayout.alignment(*type.fields()[i].fType);
    994         const Layout& fieldLayout = type.fields()[i].fModifiers.fLayout;
    995         if (fieldLayout.fOffset >= 0) {
    996             if (fieldLayout.fOffset < (int) offset) {
    997                 fErrors.error(type.fPosition,
    998                               "offset of field '" + type.fields()[i].fName + "' must be at "
    999                               "least " + to_string((int) offset));
   1000             }
   1001             if (fieldLayout.fOffset % alignment) {
   1002                 fErrors.error(type.fPosition,
   1003                               "offset of field '" + type.fields()[i].fName + "' must be a multiple"
   1004                               " of " + to_string((int) alignment));
   1005             }
   1006             offset = fieldLayout.fOffset;
   1007         } else {
   1008             size_t mod = offset % alignment;
   1009             if (mod) {
   1010                 offset += alignment - mod;
   1011             }
   1012         }
   1013         this->writeInstruction(SpvOpMemberName, resultId, i, type.fields()[i].fName.c_str(),
   1014                                fNameBuffer);
   1015         this->writeLayout(fieldLayout, resultId, i);
   1016         if (type.fields()[i].fModifiers.fLayout.fBuiltin < 0) {
   1017             this->writeInstruction(SpvOpMemberDecorate, resultId, (SpvId) i, SpvDecorationOffset,
   1018                                    (SpvId) offset, fDecorationBuffer);
   1019         }
   1020         if (type.fields()[i].fType->kind() == Type::kMatrix_Kind) {
   1021             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationColMajor,
   1022                                    fDecorationBuffer);
   1023             this->writeInstruction(SpvOpMemberDecorate, resultId, i, SpvDecorationMatrixStride,
   1024                                    (SpvId) memoryLayout.stride(*type.fields()[i].fType),
   1025                                    fDecorationBuffer);
   1026         }
   1027         offset += size;
   1028         Type::Kind kind = type.fields()[i].fType->kind();
   1029         if ((kind == Type::kArray_Kind || kind == Type::kStruct_Kind) && offset % alignment != 0) {
   1030             offset += alignment - offset % alignment;
   1031         }
   1032     }
   1033 }
   1034 
   1035 SpvId SPIRVCodeGenerator::getType(const Type& type) {
   1036     return this->getType(type, fDefaultLayout);
   1037 }
   1038 
   1039 SpvId SPIRVCodeGenerator::getType(const Type& type, const MemoryLayout& layout) {
   1040     SkString key = type.name() + to_string((int) layout.fStd);
   1041     auto entry = fTypeMap.find(key);
   1042     if (entry == fTypeMap.end()) {
   1043         SpvId result = this->nextId();
   1044         switch (type.kind()) {
   1045             case Type::kScalar_Kind:
   1046                 if (type == *fContext.fBool_Type) {
   1047                     this->writeInstruction(SpvOpTypeBool, result, fConstantBuffer);
   1048                 } else if (type == *fContext.fInt_Type) {
   1049                     this->writeInstruction(SpvOpTypeInt, result, 32, 1, fConstantBuffer);
   1050                 } else if (type == *fContext.fUInt_Type) {
   1051                     this->writeInstruction(SpvOpTypeInt, result, 32, 0, fConstantBuffer);
   1052                 } else if (type == *fContext.fFloat_Type) {
   1053                     this->writeInstruction(SpvOpTypeFloat, result, 32, fConstantBuffer);
   1054                 } else if (type == *fContext.fDouble_Type) {
   1055                     this->writeInstruction(SpvOpTypeFloat, result, 64, fConstantBuffer);
   1056                 } else {
   1057                     ASSERT(false);
   1058                 }
   1059                 break;
   1060             case Type::kVector_Kind:
   1061                 this->writeInstruction(SpvOpTypeVector, result,
   1062                                        this->getType(type.componentType(), layout),
   1063                                        type.columns(), fConstantBuffer);
   1064                 break;
   1065             case Type::kMatrix_Kind:
   1066                 this->writeInstruction(SpvOpTypeMatrix, result,
   1067                                        this->getType(index_type(fContext, type), layout),
   1068                                        type.columns(), fConstantBuffer);
   1069                 break;
   1070             case Type::kStruct_Kind:
   1071                 this->writeStruct(type, layout, result);
   1072                 break;
   1073             case Type::kArray_Kind: {
   1074                 if (type.columns() > 0) {
   1075                     IntLiteral count(fContext, Position(), type.columns());
   1076                     this->writeInstruction(SpvOpTypeArray, result,
   1077                                            this->getType(type.componentType(), layout),
   1078                                            this->writeIntLiteral(count), fConstantBuffer);
   1079                     this->writeInstruction(SpvOpDecorate, result, SpvDecorationArrayStride,
   1080                                            (int32_t) layout.stride(type),
   1081                                            fDecorationBuffer);
   1082                 } else {
   1083                     ABORT("runtime-sized arrays are not yet supported");
   1084                     this->writeInstruction(SpvOpTypeRuntimeArray, result,
   1085                                            this->getType(type.componentType(), layout),
   1086                                            fConstantBuffer);
   1087                 }
   1088                 break;
   1089             }
   1090             case Type::kSampler_Kind: {
   1091                 SpvId image = result;
   1092                 if (SpvDimSubpassData != type.dimensions()) {
   1093                     image = this->nextId();
   1094                 }
   1095                 this->writeInstruction(SpvOpTypeImage, image,
   1096                                        this->getType(*fContext.fFloat_Type, layout),
   1097                                        type.dimensions(), type.isDepth(), type.isArrayed(),
   1098                                        type.isMultisampled(), type.isSampled() ? 1 : 2,
   1099                                        SpvImageFormatUnknown, fConstantBuffer);
   1100                 if (SpvDimSubpassData != type.dimensions()) {
   1101                     this->writeInstruction(SpvOpTypeSampledImage, result, image, fConstantBuffer);
   1102                 }
   1103                 break;
   1104             }
   1105             default:
   1106                 if (type == *fContext.fVoid_Type) {
   1107                     this->writeInstruction(SpvOpTypeVoid, result, fConstantBuffer);
   1108                 } else {
   1109                     ABORT("invalid type: %s", type.description().c_str());
   1110                 }
   1111         }
   1112         fTypeMap[key] = result;
   1113         return result;
   1114     }
   1115     return entry->second;
   1116 }
   1117 
   1118 SpvId SPIRVCodeGenerator::getFunctionType(const FunctionDeclaration& function) {
   1119     SkString key = function.fReturnType.description() + "(";
   1120     SkString separator;
   1121     for (size_t i = 0; i < function.fParameters.size(); i++) {
   1122         key += separator;
   1123         separator = ", ";
   1124         key += function.fParameters[i]->fType.description();
   1125     }
   1126     key += ")";
   1127     auto entry = fTypeMap.find(key);
   1128     if (entry == fTypeMap.end()) {
   1129         SpvId result = this->nextId();
   1130         int32_t length = 3 + (int32_t) function.fParameters.size();
   1131         SpvId returnType = this->getType(function.fReturnType);
   1132         std::vector<SpvId> parameterTypes;
   1133         for (size_t i = 0; i < function.fParameters.size(); i++) {
   1134             // glslang seems to treat all function arguments as pointers whether they need to be or
   1135             // not. I  was initially puzzled by this until I ran bizarre failures with certain
   1136             // patterns of function calls and control constructs, as exemplified by this minimal
   1137             // failure case:
   1138             //
   1139             // void sphere(float x) {
   1140             // }
   1141             //
   1142             // void map() {
   1143             //     sphere(1.0);
   1144             // }
   1145             //
   1146             // void main() {
   1147             //     for (int i = 0; i < 1; i++) {
   1148             //         map();
   1149             //     }
   1150             // }
   1151             //
   1152             // As of this writing, compiling this in the "obvious" way (with sphere taking a float)
   1153             // crashes. Making it take a float* and storing the argument in a temporary variable,
   1154             // as glslang does, fixes it. It's entirely possible I simply missed whichever part of
   1155             // the spec makes this make sense.
   1156 //            if (is_out(function->fParameters[i])) {
   1157                 parameterTypes.push_back(this->getPointerType(function.fParameters[i]->fType,
   1158                                                               SpvStorageClassFunction));
   1159 //            } else {
   1160 //                parameterTypes.push_back(this->getType(function.fParameters[i]->fType));
   1161 //            }
   1162         }
   1163         this->writeOpCode(SpvOpTypeFunction, length, fConstantBuffer);
   1164         this->writeWord(result, fConstantBuffer);
   1165         this->writeWord(returnType, fConstantBuffer);
   1166         for (SpvId id : parameterTypes) {
   1167             this->writeWord(id, fConstantBuffer);
   1168         }
   1169         fTypeMap[key] = result;
   1170         return result;
   1171     }
   1172     return entry->second;
   1173 }
   1174 
   1175 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, SpvStorageClass_ storageClass) {
   1176     return this->getPointerType(type, fDefaultLayout, storageClass);
   1177 }
   1178 
   1179 SpvId SPIRVCodeGenerator::getPointerType(const Type& type, const MemoryLayout& layout,
   1180                                          SpvStorageClass_ storageClass) {
   1181     SkString key = type.description() + "*" + to_string(layout.fStd) + to_string(storageClass);
   1182     auto entry = fTypeMap.find(key);
   1183     if (entry == fTypeMap.end()) {
   1184         SpvId result = this->nextId();
   1185         this->writeInstruction(SpvOpTypePointer, result, storageClass,
   1186                                this->getType(type), fConstantBuffer);
   1187         fTypeMap[key] = result;
   1188         return result;
   1189     }
   1190     return entry->second;
   1191 }
   1192 
   1193 SpvId SPIRVCodeGenerator::writeExpression(const Expression& expr, SkWStream& out) {
   1194     switch (expr.fKind) {
   1195         case Expression::kBinary_Kind:
   1196             return this->writeBinaryExpression((BinaryExpression&) expr, out);
   1197         case Expression::kBoolLiteral_Kind:
   1198             return this->writeBoolLiteral((BoolLiteral&) expr);
   1199         case Expression::kConstructor_Kind:
   1200             return this->writeConstructor((Constructor&) expr, out);
   1201         case Expression::kIntLiteral_Kind:
   1202             return this->writeIntLiteral((IntLiteral&) expr);
   1203         case Expression::kFieldAccess_Kind:
   1204             return this->writeFieldAccess(((FieldAccess&) expr), out);
   1205         case Expression::kFloatLiteral_Kind:
   1206             return this->writeFloatLiteral(((FloatLiteral&) expr));
   1207         case Expression::kFunctionCall_Kind:
   1208             return this->writeFunctionCall((FunctionCall&) expr, out);
   1209         case Expression::kPrefix_Kind:
   1210             return this->writePrefixExpression((PrefixExpression&) expr, out);
   1211         case Expression::kPostfix_Kind:
   1212             return this->writePostfixExpression((PostfixExpression&) expr, out);
   1213         case Expression::kSwizzle_Kind:
   1214             return this->writeSwizzle((Swizzle&) expr, out);
   1215         case Expression::kVariableReference_Kind:
   1216             return this->writeVariableReference((VariableReference&) expr, out);
   1217         case Expression::kTernary_Kind:
   1218             return this->writeTernaryExpression((TernaryExpression&) expr, out);
   1219         case Expression::kIndex_Kind:
   1220             return this->writeIndexExpression((IndexExpression&) expr, out);
   1221         default:
   1222             ABORT("unsupported expression: %s", expr.description().c_str());
   1223     }
   1224     return -1;
   1225 }
   1226 
   1227 SpvId SPIRVCodeGenerator::writeIntrinsicCall(const FunctionCall& c, SkWStream& out) {
   1228     auto intrinsic = fIntrinsicMap.find(c.fFunction.fName);
   1229     ASSERT(intrinsic != fIntrinsicMap.end());
   1230     const Type& type = c.fArguments[0]->fType;
   1231     int32_t intrinsicId;
   1232     if (std::get<0>(intrinsic->second) == kSpecial_IntrinsicKind || is_float(fContext, type)) {
   1233         intrinsicId = std::get<1>(intrinsic->second);
   1234     } else if (is_signed(fContext, type)) {
   1235         intrinsicId = std::get<2>(intrinsic->second);
   1236     } else if (is_unsigned(fContext, type)) {
   1237         intrinsicId = std::get<3>(intrinsic->second);
   1238     } else if (is_bool(fContext, type)) {
   1239         intrinsicId = std::get<4>(intrinsic->second);
   1240     } else {
   1241         ABORT("invalid call %s, cannot operate on '%s'", c.description().c_str(),
   1242               type.description().c_str());
   1243     }
   1244     switch (std::get<0>(intrinsic->second)) {
   1245         case kGLSL_STD_450_IntrinsicKind: {
   1246             SpvId result = this->nextId();
   1247             std::vector<SpvId> arguments;
   1248             for (size_t i = 0; i < c.fArguments.size(); i++) {
   1249                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
   1250             }
   1251             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
   1252             this->writeWord(this->getType(c.fType), out);
   1253             this->writeWord(result, out);
   1254             this->writeWord(fGLSLExtendedInstructions, out);
   1255             this->writeWord(intrinsicId, out);
   1256             for (SpvId id : arguments) {
   1257                 this->writeWord(id, out);
   1258             }
   1259             return result;
   1260         }
   1261         case kSPIRV_IntrinsicKind: {
   1262             SpvId result = this->nextId();
   1263             std::vector<SpvId> arguments;
   1264             for (size_t i = 0; i < c.fArguments.size(); i++) {
   1265                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
   1266             }
   1267             this->writeOpCode((SpvOp_) intrinsicId, 3 + (int32_t) arguments.size(), out);
   1268             this->writeWord(this->getType(c.fType), out);
   1269             this->writeWord(result, out);
   1270             for (SpvId id : arguments) {
   1271                 this->writeWord(id, out);
   1272             }
   1273             return result;
   1274         }
   1275         case kSpecial_IntrinsicKind:
   1276             return this->writeSpecialIntrinsic(c, (SpecialIntrinsic) intrinsicId, out);
   1277         default:
   1278             ABORT("unsupported intrinsic kind");
   1279     }
   1280 }
   1281 
   1282 SpvId SPIRVCodeGenerator::writeSpecialIntrinsic(const FunctionCall& c, SpecialIntrinsic kind,
   1283                                                 SkWStream& out) {
   1284     SpvId result = this->nextId();
   1285     switch (kind) {
   1286         case kAtan_SpecialIntrinsic: {
   1287             std::vector<SpvId> arguments;
   1288             for (size_t i = 0; i < c.fArguments.size(); i++) {
   1289                 arguments.push_back(this->writeExpression(*c.fArguments[i], out));
   1290             }
   1291             this->writeOpCode(SpvOpExtInst, 5 + (int32_t) arguments.size(), out);
   1292             this->writeWord(this->getType(c.fType), out);
   1293             this->writeWord(result, out);
   1294             this->writeWord(fGLSLExtendedInstructions, out);
   1295             this->writeWord(arguments.size() == 2 ? GLSLstd450Atan2 : GLSLstd450Atan, out);
   1296             for (SpvId id : arguments) {
   1297                 this->writeWord(id, out);
   1298             }
   1299             return result;
   1300         }
   1301         case kTexture_SpecialIntrinsic: {
   1302             SpvOp_ op = SpvOpImageSampleImplicitLod;
   1303             switch (c.fArguments[0]->fType.dimensions()) {
   1304                 case SpvDim1D:
   1305                     if (c.fArguments[1]->fType == *fContext.fVec2_Type) {
   1306                         op = SpvOpImageSampleProjImplicitLod;
   1307                     } else {
   1308                         ASSERT(c.fArguments[1]->fType == *fContext.fFloat_Type);
   1309                     }
   1310                     break;
   1311                 case SpvDim2D:
   1312                     if (c.fArguments[1]->fType == *fContext.fVec3_Type) {
   1313                         op = SpvOpImageSampleProjImplicitLod;
   1314                     } else {
   1315                         ASSERT(c.fArguments[1]->fType == *fContext.fVec2_Type);
   1316                     }
   1317                     break;
   1318                 case SpvDim3D:
   1319                     if (c.fArguments[1]->fType == *fContext.fVec4_Type) {
   1320                         op = SpvOpImageSampleProjImplicitLod;
   1321                     } else {
   1322                         ASSERT(c.fArguments[1]->fType == *fContext.fVec3_Type);
   1323                     }
   1324                     break;
   1325                 case SpvDimCube:   // fall through
   1326                 case SpvDimRect:   // fall through
   1327                 case SpvDimBuffer: // fall through
   1328                 case SpvDimSubpassData:
   1329                     break;
   1330             }
   1331             SpvId type = this->getType(c.fType);
   1332             SpvId sampler = this->writeExpression(*c.fArguments[0], out);
   1333             SpvId uv = this->writeExpression(*c.fArguments[1], out);
   1334             if (c.fArguments.size() == 3) {
   1335                 this->writeInstruction(op, type, result, sampler, uv,
   1336                                        SpvImageOperandsBiasMask,
   1337                                        this->writeExpression(*c.fArguments[2], out),
   1338                                        out);
   1339             } else {
   1340                 ASSERT(c.fArguments.size() == 2);
   1341                 this->writeInstruction(op, type, result, sampler, uv,
   1342                                        out);
   1343             }
   1344             break;
   1345         }
   1346         case kSubpassLoad_SpecialIntrinsic: {
   1347             SpvId img = this->writeExpression(*c.fArguments[0], out);
   1348             std::vector<std::unique_ptr<Expression>> args;
   1349             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
   1350             args.emplace_back(new FloatLiteral(fContext, Position(), 0.0));
   1351             Constructor ctor(Position(), *fContext.fVec2_Type, std::move(args));
   1352             SpvId coords = this->writeConstantVector(ctor);
   1353             if (1 == c.fArguments.size()) {
   1354                 this->writeInstruction(SpvOpImageRead,
   1355                                        this->getType(c.fType),
   1356                                        result,
   1357                                        img,
   1358                                        coords,
   1359                                        out);
   1360             } else {
   1361                 SkASSERT(2 == c.fArguments.size());
   1362                 SpvId sample = this->writeExpression(*c.fArguments[1], out);
   1363                 this->writeInstruction(SpvOpImageRead,
   1364                                        this->getType(c.fType),
   1365                                        result,
   1366                                        img,
   1367                                        coords,
   1368                                        SpvImageOperandsSampleMask,
   1369                                        sample,
   1370                                        out);
   1371             }
   1372             break;
   1373         }
   1374     }
   1375     return result;
   1376 }
   1377 
   1378 SpvId SPIRVCodeGenerator::writeFunctionCall(const FunctionCall& c, SkWStream& out) {
   1379     const auto& entry = fFunctionMap.find(&c.fFunction);
   1380     if (entry == fFunctionMap.end()) {
   1381         return this->writeIntrinsicCall(c, out);
   1382     }
   1383     // stores (variable, type, lvalue) pairs to extract and save after the function call is complete
   1384     std::vector<std::tuple<SpvId, SpvId, std::unique_ptr<LValue>>> lvalues;
   1385     std::vector<SpvId> arguments;
   1386     for (size_t i = 0; i < c.fArguments.size(); i++) {
   1387         // id of temporary variable that we will use to hold this argument, or 0 if it is being
   1388         // passed directly
   1389         SpvId tmpVar;
   1390         // if we need a temporary var to store this argument, this is the value to store in the var
   1391         SpvId tmpValueId;
   1392         if (is_out(*c.fFunction.fParameters[i])) {
   1393             std::unique_ptr<LValue> lv = this->getLValue(*c.fArguments[i], out);
   1394             SpvId ptr = lv->getPointer();
   1395             if (ptr) {
   1396                 arguments.push_back(ptr);
   1397                 continue;
   1398             } else {
   1399                 // lvalue cannot simply be read and written via a pointer (e.g. a swizzle). Need to
   1400                 // copy it into a temp, call the function, read the value out of the temp, and then
   1401                 // update the lvalue.
   1402                 tmpValueId = lv->load(out);
   1403                 tmpVar = this->nextId();
   1404                 lvalues.push_back(std::make_tuple(tmpVar, this->getType(c.fArguments[i]->fType),
   1405                                   std::move(lv)));
   1406             }
   1407         } else {
   1408             // see getFunctionType for an explanation of why we're always using pointer parameters
   1409             tmpValueId = this->writeExpression(*c.fArguments[i], out);
   1410             tmpVar = this->nextId();
   1411         }
   1412         this->writeInstruction(SpvOpVariable,
   1413                                this->getPointerType(c.fArguments[i]->fType,
   1414                                                     SpvStorageClassFunction),
   1415                                tmpVar,
   1416                                SpvStorageClassFunction,
   1417                                fVariableBuffer);
   1418         this->writeInstruction(SpvOpStore, tmpVar, tmpValueId, out);
   1419         arguments.push_back(tmpVar);
   1420     }
   1421     SpvId result = this->nextId();
   1422     this->writeOpCode(SpvOpFunctionCall, 4 + (int32_t) c.fArguments.size(), out);
   1423     this->writeWord(this->getType(c.fType), out);
   1424     this->writeWord(result, out);
   1425     this->writeWord(entry->second, out);
   1426     for (SpvId id : arguments) {
   1427         this->writeWord(id, out);
   1428     }
   1429     // now that the call is complete, we may need to update some lvalues with the new values of out
   1430     // arguments
   1431     for (const auto& tuple : lvalues) {
   1432         SpvId load = this->nextId();
   1433         this->writeInstruction(SpvOpLoad, std::get<1>(tuple), load, std::get<0>(tuple), out);
   1434         std::get<2>(tuple)->store(load, out);
   1435     }
   1436     return result;
   1437 }
   1438 
   1439 SpvId SPIRVCodeGenerator::writeConstantVector(const Constructor& c) {
   1440     ASSERT(c.fType.kind() == Type::kVector_Kind && c.isConstant());
   1441     SpvId result = this->nextId();
   1442     std::vector<SpvId> arguments;
   1443     for (size_t i = 0; i < c.fArguments.size(); i++) {
   1444         arguments.push_back(this->writeExpression(*c.fArguments[i], fConstantBuffer));
   1445     }
   1446     SpvId type = this->getType(c.fType);
   1447     if (c.fArguments.size() == 1) {
   1448         // with a single argument, a vector will have all of its entries equal to the argument
   1449         this->writeOpCode(SpvOpConstantComposite, 3 + c.fType.columns(), fConstantBuffer);
   1450         this->writeWord(type, fConstantBuffer);
   1451         this->writeWord(result, fConstantBuffer);
   1452         for (int i = 0; i < c.fType.columns(); i++) {
   1453             this->writeWord(arguments[0], fConstantBuffer);
   1454         }
   1455     } else {
   1456         this->writeOpCode(SpvOpConstantComposite, 3 + (int32_t) c.fArguments.size(),
   1457                           fConstantBuffer);
   1458         this->writeWord(type, fConstantBuffer);
   1459         this->writeWord(result, fConstantBuffer);
   1460         for (SpvId id : arguments) {
   1461             this->writeWord(id, fConstantBuffer);
   1462         }
   1463     }
   1464     return result;
   1465 }
   1466 
   1467 SpvId SPIRVCodeGenerator::writeFloatConstructor(const Constructor& c, SkWStream& out) {
   1468     ASSERT(c.fType == *fContext.fFloat_Type);
   1469     ASSERT(c.fArguments.size() == 1);
   1470     ASSERT(c.fArguments[0]->fType.isNumber());
   1471     SpvId result = this->nextId();
   1472     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
   1473     if (c.fArguments[0]->fType == *fContext.fInt_Type) {
   1474         this->writeInstruction(SpvOpConvertSToF, this->getType(c.fType), result, parameter,
   1475                                out);
   1476     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
   1477         this->writeInstruction(SpvOpConvertUToF, this->getType(c.fType), result, parameter,
   1478                                out);
   1479     } else if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
   1480         return parameter;
   1481     }
   1482     return result;
   1483 }
   1484 
   1485 SpvId SPIRVCodeGenerator::writeIntConstructor(const Constructor& c, SkWStream& out) {
   1486     ASSERT(c.fType == *fContext.fInt_Type);
   1487     ASSERT(c.fArguments.size() == 1);
   1488     ASSERT(c.fArguments[0]->fType.isNumber());
   1489     SpvId result = this->nextId();
   1490     SpvId parameter = this->writeExpression(*c.fArguments[0], out);
   1491     if (c.fArguments[0]->fType == *fContext.fFloat_Type) {
   1492         this->writeInstruction(SpvOpConvertFToS, this->getType(c.fType), result, parameter,
   1493                                out);
   1494     } else if (c.fArguments[0]->fType == *fContext.fUInt_Type) {
   1495         this->writeInstruction(SpvOpSatConvertUToS, this->getType(c.fType), result, parameter,
   1496                                out);
   1497     } else if (c.fArguments[0]->fType == *fContext.fInt_Type) {
   1498         return parameter;
   1499     }
   1500     return result;
   1501 }
   1502 
   1503 void SPIRVCodeGenerator::writeUniformScaleMatrix(SpvId id, SpvId diagonal, const Type& type,
   1504                                                  SkWStream& out) {
   1505     FloatLiteral zero(fContext, Position(), 0);
   1506     SpvId zeroId = this->writeFloatLiteral(zero);
   1507     std::vector<SpvId> columnIds;
   1508     for (int column = 0; column < type.columns(); column++) {
   1509         this->writeOpCode(SpvOpCompositeConstruct, 3 + type.rows(),
   1510                           out);
   1511         this->writeWord(this->getType(type.componentType().toCompound(fContext, type.rows(), 1)),
   1512                         out);
   1513         SpvId columnId = this->nextId();
   1514         this->writeWord(columnId, out);
   1515         columnIds.push_back(columnId);
   1516         for (int row = 0; row < type.columns(); row++) {
   1517             this->writeWord(row == column ? diagonal : zeroId, out);
   1518         }
   1519     }
   1520     this->writeOpCode(SpvOpCompositeConstruct, 3 + type.columns(),
   1521                       out);
   1522     this->writeWord(this->getType(type), out);
   1523     this->writeWord(id, out);
   1524     for (SpvId id : columnIds) {
   1525         this->writeWord(id, out);
   1526     }
   1527 }
   1528 
   1529 void SPIRVCodeGenerator::writeMatrixCopy(SpvId id, SpvId src, const Type& srcType,
   1530                                          const Type& dstType, SkWStream& out) {
   1531     ABORT("unimplemented");
   1532 }
   1533 
   1534 SpvId SPIRVCodeGenerator::writeMatrixConstructor(const Constructor& c, SkWStream& out) {
   1535     ASSERT(c.fType.kind() == Type::kMatrix_Kind);
   1536     // go ahead and write the arguments so we don't try to write new instructions in the middle of
   1537     // an instruction
   1538     std::vector<SpvId> arguments;
   1539     for (size_t i = 0; i < c.fArguments.size(); i++) {
   1540         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
   1541     }
   1542     SpvId result = this->nextId();
   1543     int rows = c.fType.rows();
   1544     int columns = c.fType.columns();
   1545     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
   1546         this->writeUniformScaleMatrix(result, arguments[0], c.fType, out);
   1547     } else if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kMatrix_Kind) {
   1548         this->writeMatrixCopy(result, arguments[0], c.fArguments[0]->fType, c.fType, out);
   1549     } else {
   1550         std::vector<SpvId> columnIds;
   1551         int currentCount = 0;
   1552         for (size_t i = 0; i < arguments.size(); i++) {
   1553             if (c.fArguments[i]->fType.kind() == Type::kVector_Kind) {
   1554                 ASSERT(currentCount == 0);
   1555                 columnIds.push_back(arguments[i]);
   1556                 currentCount = 0;
   1557             } else {
   1558                 ASSERT(c.fArguments[i]->fType.kind() == Type::kScalar_Kind);
   1559                 if (currentCount == 0) {
   1560                     this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.rows(), out);
   1561                     this->writeWord(this->getType(c.fType.componentType().toCompound(fContext, rows,
   1562                                                                                      1)),
   1563                                     out);
   1564                     SpvId id = this->nextId();
   1565                     this->writeWord(id, out);
   1566                     columnIds.push_back(id);
   1567                 }
   1568                 this->writeWord(arguments[i], out);
   1569                 currentCount = (currentCount + 1) % rows;
   1570             }
   1571         }
   1572         ASSERT(columnIds.size() == (size_t) columns);
   1573         this->writeOpCode(SpvOpCompositeConstruct, 3 + columns, out);
   1574         this->writeWord(this->getType(c.fType), out);
   1575         this->writeWord(result, out);
   1576         for (SpvId id : columnIds) {
   1577             this->writeWord(id, out);
   1578         }
   1579     }
   1580     return result;
   1581 }
   1582 
   1583 SpvId SPIRVCodeGenerator::writeVectorConstructor(const Constructor& c, SkWStream& out) {
   1584     ASSERT(c.fType.kind() == Type::kVector_Kind);
   1585     if (c.isConstant()) {
   1586         return this->writeConstantVector(c);
   1587     }
   1588     // go ahead and write the arguments so we don't try to write new instructions in the middle of
   1589     // an instruction
   1590     std::vector<SpvId> arguments;
   1591     for (size_t i = 0; i < c.fArguments.size(); i++) {
   1592         arguments.push_back(this->writeExpression(*c.fArguments[i], out));
   1593     }
   1594     SpvId result = this->nextId();
   1595     if (arguments.size() == 1 && c.fArguments[0]->fType.kind() == Type::kScalar_Kind) {
   1596         this->writeOpCode(SpvOpCompositeConstruct, 3 + c.fType.columns(), out);
   1597         this->writeWord(this->getType(c.fType), out);
   1598         this->writeWord(result, out);
   1599         for (int i = 0; i < c.fType.columns(); i++) {
   1600             this->writeWord(arguments[0], out);
   1601         }
   1602     } else {
   1603         this->writeOpCode(SpvOpCompositeConstruct, 3 + (int32_t) c.fArguments.size(), out);
   1604         this->writeWord(this->getType(c.fType), out);
   1605         this->writeWord(result, out);
   1606         for (SpvId id : arguments) {
   1607             this->writeWord(id, out);
   1608         }
   1609     }
   1610     return result;
   1611 }
   1612 
   1613 SpvId SPIRVCodeGenerator::writeConstructor(const Constructor& c, SkWStream& out) {
   1614     if (c.fType == *fContext.fFloat_Type) {
   1615         return this->writeFloatConstructor(c, out);
   1616     } else if (c.fType == *fContext.fInt_Type) {
   1617         return this->writeIntConstructor(c, out);
   1618     }
   1619     switch (c.fType.kind()) {
   1620         case Type::kVector_Kind:
   1621             return this->writeVectorConstructor(c, out);
   1622         case Type::kMatrix_Kind:
   1623             return this->writeMatrixConstructor(c, out);
   1624         default:
   1625             ABORT("unsupported constructor: %s", c.description().c_str());
   1626     }
   1627 }
   1628 
   1629 SpvStorageClass_ get_storage_class(const Modifiers& modifiers) {
   1630     if (modifiers.fFlags & Modifiers::kIn_Flag) {
   1631         ASSERT(!modifiers.fLayout.fPushConstant);
   1632         return SpvStorageClassInput;
   1633     } else if (modifiers.fFlags & Modifiers::kOut_Flag) {
   1634         ASSERT(!modifiers.fLayout.fPushConstant);
   1635         return SpvStorageClassOutput;
   1636     } else if (modifiers.fFlags & Modifiers::kUniform_Flag) {
   1637         if (modifiers.fLayout.fPushConstant) {
   1638             return SpvStorageClassPushConstant;
   1639         }
   1640         return SpvStorageClassUniform;
   1641     } else {
   1642         return SpvStorageClassFunction;
   1643     }
   1644 }
   1645 
   1646 SpvStorageClass_ get_storage_class(const Expression& expr) {
   1647     switch (expr.fKind) {
   1648         case Expression::kVariableReference_Kind:
   1649             return get_storage_class(((VariableReference&) expr).fVariable.fModifiers);
   1650         case Expression::kFieldAccess_Kind:
   1651             return get_storage_class(*((FieldAccess&) expr).fBase);
   1652         case Expression::kIndex_Kind:
   1653             return get_storage_class(*((IndexExpression&) expr).fBase);
   1654         default:
   1655             return SpvStorageClassFunction;
   1656     }
   1657 }
   1658 
   1659 std::vector<SpvId> SPIRVCodeGenerator::getAccessChain(const Expression& expr, SkWStream& out) {
   1660     std::vector<SpvId> chain;
   1661     switch (expr.fKind) {
   1662         case Expression::kIndex_Kind: {
   1663             IndexExpression& indexExpr = (IndexExpression&) expr;
   1664             chain = this->getAccessChain(*indexExpr.fBase, out);
   1665             chain.push_back(this->writeExpression(*indexExpr.fIndex, out));
   1666             break;
   1667         }
   1668         case Expression::kFieldAccess_Kind: {
   1669             FieldAccess& fieldExpr = (FieldAccess&) expr;
   1670             chain = this->getAccessChain(*fieldExpr.fBase, out);
   1671             IntLiteral index(fContext, Position(), fieldExpr.fFieldIndex);
   1672             chain.push_back(this->writeIntLiteral(index));
   1673             break;
   1674         }
   1675         default:
   1676             chain.push_back(this->getLValue(expr, out)->getPointer());
   1677     }
   1678     return chain;
   1679 }
   1680 
   1681 class PointerLValue : public SPIRVCodeGenerator::LValue {
   1682 public:
   1683     PointerLValue(SPIRVCodeGenerator& gen, SpvId pointer, SpvId type)
   1684     : fGen(gen)
   1685     , fPointer(pointer)
   1686     , fType(type) {}
   1687 
   1688     virtual SpvId getPointer() override {
   1689         return fPointer;
   1690     }
   1691 
   1692     virtual SpvId load(SkWStream& out) override {
   1693         SpvId result = fGen.nextId();
   1694         fGen.writeInstruction(SpvOpLoad, fType, result, fPointer, out);
   1695         return result;
   1696     }
   1697 
   1698     virtual void store(SpvId value, SkWStream& out) override {
   1699         fGen.writeInstruction(SpvOpStore, fPointer, value, out);
   1700     }
   1701 
   1702 private:
   1703     SPIRVCodeGenerator& fGen;
   1704     const SpvId fPointer;
   1705     const SpvId fType;
   1706 };
   1707 
   1708 class SwizzleLValue : public SPIRVCodeGenerator::LValue {
   1709 public:
   1710     SwizzleLValue(SPIRVCodeGenerator& gen, SpvId vecPointer, const std::vector<int>& components,
   1711                   const Type& baseType, const Type& swizzleType)
   1712     : fGen(gen)
   1713     , fVecPointer(vecPointer)
   1714     , fComponents(components)
   1715     , fBaseType(baseType)
   1716     , fSwizzleType(swizzleType) {}
   1717 
   1718     virtual SpvId getPointer() override {
   1719         return 0;
   1720     }
   1721 
   1722     virtual SpvId load(SkWStream& out) override {
   1723         SpvId base = fGen.nextId();
   1724         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
   1725         SpvId result = fGen.nextId();
   1726         fGen.writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) fComponents.size(), out);
   1727         fGen.writeWord(fGen.getType(fSwizzleType), out);
   1728         fGen.writeWord(result, out);
   1729         fGen.writeWord(base, out);
   1730         fGen.writeWord(base, out);
   1731         for (int component : fComponents) {
   1732             fGen.writeWord(component, out);
   1733         }
   1734         return result;
   1735     }
   1736 
   1737     virtual void store(SpvId value, SkWStream& out) override {
   1738         // use OpVectorShuffle to mix and match the vector components. We effectively create
   1739         // a virtual vector out of the concatenation of the left and right vectors, and then
   1740         // select components from this virtual vector to make the result vector. For
   1741         // instance, given:
   1742         // vec3 L = ...;
   1743         // vec3 R = ...;
   1744         // L.xz = R.xy;
   1745         // we end up with the virtual vector (L.x, L.y, L.z, R.x, R.y, R.z). Then we want
   1746         // our result vector to look like (R.x, L.y, R.y), so we need to select indices
   1747         // (3, 1, 4).
   1748         SpvId base = fGen.nextId();
   1749         fGen.writeInstruction(SpvOpLoad, fGen.getType(fBaseType), base, fVecPointer, out);
   1750         SpvId shuffle = fGen.nextId();
   1751         fGen.writeOpCode(SpvOpVectorShuffle, 5 + fBaseType.columns(), out);
   1752         fGen.writeWord(fGen.getType(fBaseType), out);
   1753         fGen.writeWord(shuffle, out);
   1754         fGen.writeWord(base, out);
   1755         fGen.writeWord(value, out);
   1756         for (int i = 0; i < fBaseType.columns(); i++) {
   1757             // current offset into the virtual vector, defaults to pulling the unmodified
   1758             // value from the left side
   1759             int offset = i;
   1760             // check to see if we are writing this component
   1761             for (size_t j = 0; j < fComponents.size(); j++) {
   1762                 if (fComponents[j] == i) {
   1763                     // we're writing to this component, so adjust the offset to pull from
   1764                     // the correct component of the right side instead of preserving the
   1765                     // value from the left
   1766                     offset = (int) (j + fBaseType.columns());
   1767                     break;
   1768                 }
   1769             }
   1770             fGen.writeWord(offset, out);
   1771         }
   1772         fGen.writeInstruction(SpvOpStore, fVecPointer, shuffle, out);
   1773     }
   1774 
   1775 private:
   1776     SPIRVCodeGenerator& fGen;
   1777     const SpvId fVecPointer;
   1778     const std::vector<int>& fComponents;
   1779     const Type& fBaseType;
   1780     const Type& fSwizzleType;
   1781 };
   1782 
   1783 std::unique_ptr<SPIRVCodeGenerator::LValue> SPIRVCodeGenerator::getLValue(const Expression& expr,
   1784                                                                           SkWStream& out) {
   1785     switch (expr.fKind) {
   1786         case Expression::kVariableReference_Kind: {
   1787             const Variable& var = ((VariableReference&) expr).fVariable;
   1788             auto entry = fVariableMap.find(&var);
   1789             ASSERT(entry != fVariableMap.end());
   1790             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
   1791                                                                        *this,
   1792                                                                        entry->second,
   1793                                                                        this->getType(expr.fType)));
   1794         }
   1795         case Expression::kIndex_Kind: // fall through
   1796         case Expression::kFieldAccess_Kind: {
   1797             std::vector<SpvId> chain = this->getAccessChain(expr, out);
   1798             SpvId member = this->nextId();
   1799             this->writeOpCode(SpvOpAccessChain, (SpvId) (3 + chain.size()), out);
   1800             this->writeWord(this->getPointerType(expr.fType, get_storage_class(expr)), out);
   1801             this->writeWord(member, out);
   1802             for (SpvId idx : chain) {
   1803                 this->writeWord(idx, out);
   1804             }
   1805             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
   1806                                                                        *this,
   1807                                                                        member,
   1808                                                                        this->getType(expr.fType)));
   1809         }
   1810 
   1811         case Expression::kSwizzle_Kind: {
   1812             Swizzle& swizzle = (Swizzle&) expr;
   1813             size_t count = swizzle.fComponents.size();
   1814             SpvId base = this->getLValue(*swizzle.fBase, out)->getPointer();
   1815             ASSERT(base);
   1816             if (count == 1) {
   1817                 IntLiteral index(fContext, Position(), swizzle.fComponents[0]);
   1818                 SpvId member = this->nextId();
   1819                 this->writeInstruction(SpvOpAccessChain,
   1820                                        this->getPointerType(swizzle.fType,
   1821                                                             get_storage_class(*swizzle.fBase)),
   1822                                        member,
   1823                                        base,
   1824                                        this->writeIntLiteral(index),
   1825                                        out);
   1826                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
   1827                                                                        *this,
   1828                                                                        member,
   1829                                                                        this->getType(expr.fType)));
   1830             } else {
   1831                 return std::unique_ptr<SPIRVCodeGenerator::LValue>(new SwizzleLValue(
   1832                                                                               *this,
   1833                                                                               base,
   1834                                                                               swizzle.fComponents,
   1835                                                                               swizzle.fBase->fType,
   1836                                                                               expr.fType));
   1837             }
   1838         }
   1839 
   1840         default:
   1841             // expr isn't actually an lvalue, create a dummy variable for it. This case happens due
   1842             // to the need to store values in temporary variables during function calls (see
   1843             // comments in getFunctionType); erroneous uses of rvalues as lvalues should have been
   1844             // caught by IRGenerator
   1845             SpvId result = this->nextId();
   1846             SpvId type = this->getPointerType(expr.fType, SpvStorageClassFunction);
   1847             this->writeInstruction(SpvOpVariable, type, result, SpvStorageClassFunction,
   1848                                    fVariableBuffer);
   1849             this->writeInstruction(SpvOpStore, result, this->writeExpression(expr, out), out);
   1850             return std::unique_ptr<SPIRVCodeGenerator::LValue>(new PointerLValue(
   1851                                                                        *this,
   1852                                                                        result,
   1853                                                                        this->getType(expr.fType)));
   1854     }
   1855 }
   1856 
/**
 * Emits an OpLoad of the variable referenced by ref and returns the id of the loaded value.
 *
 * Special case: when the variable is the fragment-coordinate builtin (SK_FRAGCOORD_BUILTIN) and
 * fProgram.fSettings.fFlipY is set, the loaded value is not returned directly. Instead a new
 * vec4(fragCoord.x, rtHeight - fragCoord.y, 0.0, 1.0) is assembled and its id is returned.
 * rtHeight is read from a synthesized uniform interface block. In that case the z and w
 * components of the loaded coordinate are replaced by the constants 0.0 and 1.0.
 */
SpvId SPIRVCodeGenerator::writeVariableReference(const VariableReference& ref, SkWStream& out) {
    SpvId result = this->nextId();
    // The variable must already have been given an id in fVariableMap.
    auto entry = fVariableMap.find(&ref.fVariable);
    ASSERT(entry != fVariableMap.end());
    SpvId var = entry->second;
    this->writeInstruction(SpvOpLoad, this->getType(ref.fVariable.fType), result, var, out);
    if (ref.fVariable.fModifiers.fLayout.fBuiltin == SK_FRAGCOORD_BUILTIN &&
        fProgram.fSettings.fFlipY) {
        // need to remap to a top-left coordinate system
        if (fRTHeightStructId == (SpvId) -1) {
            // height variable hasn't been written yet: synthesize a uniform interface block
            // named "sksl_synthetic_uniforms" holding a single float field (SKSL_RTHEIGHT_NAME).
            // Its id is remembered in fRTHeightStructId so later references reuse it.
            std::shared_ptr<SymbolTable> st(new SymbolTable(fErrors));
            ASSERT(fRTHeightFieldIndex == (SpvId) -1);
            std::vector<Type::Field> fields;
            fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME),
                                fContext.fFloat_Type.get());
            SkString name("sksl_synthetic_uniforms");
            Type intfStruct(Position(), name, fields);
            // NOTE(review): positional Layout arguments; the third one (1) is presumably the
            // block's binding — confirm against Layout's constructor.
            Layout layout(-1, -1, 1, -1, -1, -1, -1, false, false, false, Layout::Format::kUnspecified,
                          false, Layout::kUnspecified_Primitive, -1, -1);
            Variable intfVar(Position(), Modifiers(layout, Modifiers::kUniform_Flag), name,
                             intfStruct, Variable::kGlobal_Storage);
            InterfaceBlock intf(Position(), intfVar, name, SkString(""),
                                std::vector<std::unique_ptr<Expression>>(), st);
            fRTHeightStructId = this->writeInterfaceBlock(intf);
            // the height is the block's only field
            fRTHeightFieldIndex = 0;
        }
        ASSERT(fRTHeightFieldIndex != (SpvId) -1);
        // write vec4(gl_FragCoord.x, u_skRTHeight - gl_FragCoord.y, 0.0, 1.0)
        SpvId xId = this->nextId();
        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), xId,
                               result, 0, out);
        // Get a pointer to the height field. The OpAccessChain is written by hand and is 5 words:
        // opcode, result type, result id, base, and one index.
        IntLiteral fieldIndex(fContext, Position(), fRTHeightFieldIndex);
        SpvId fieldIndexId = this->writeIntLiteral(fieldIndex);
        SpvId heightPtr = this->nextId();
        this->writeOpCode(SpvOpAccessChain, 5, out);
        this->writeWord(this->getPointerType(*fContext.fFloat_Type, SpvStorageClassUniform), out);
        this->writeWord(heightPtr, out);
        this->writeWord(fRTHeightStructId, out);
        this->writeWord(fieldIndexId, out);
        SpvId heightRead = this->nextId();
        this->writeInstruction(SpvOpLoad, this->getType(*fContext.fFloat_Type), heightRead,
                               heightPtr, out);
        // flipped y = height - y
        SpvId rawYId = this->nextId();
        this->writeInstruction(SpvOpCompositeExtract, this->getType(*fContext.fFloat_Type), rawYId,
                               result, 1, out);
        SpvId flippedYId = this->nextId();
        this->writeInstruction(SpvOpFSub, this->getType(*fContext.fFloat_Type), flippedYId,
                               heightRead, rawYId, out);
        FloatLiteral zero(fContext, Position(), 0.0);
        SpvId zeroId = writeFloatLiteral(zero);
        FloatLiteral one(fContext, Position(), 1.0);
        SpvId oneId = writeFloatLiteral(one);
        // OpCompositeConstruct is 7 words: opcode, result type, result id, and four components.
        SpvId flipped = this->nextId();
        this->writeOpCode(SpvOpCompositeConstruct, 7, out);
        this->writeWord(this->getType(*fContext.fVec4_Type), out);
        this->writeWord(flipped, out);
        this->writeWord(xId, out);
        this->writeWord(flippedYId, out);
        this->writeWord(zeroId, out);
        this->writeWord(oneId, out);
        return flipped;
    }
    return result;
}
   1922 
   1923 SpvId SPIRVCodeGenerator::writeIndexExpression(const IndexExpression& expr, SkWStream& out) {
   1924     return getLValue(expr, out)->load(out);
   1925 }
   1926 
   1927 SpvId SPIRVCodeGenerator::writeFieldAccess(const FieldAccess& f, SkWStream& out) {
   1928     return getLValue(f, out)->load(out);
   1929 }
   1930 
   1931 SpvId SPIRVCodeGenerator::writeSwizzle(const Swizzle& swizzle, SkWStream& out) {
   1932     SpvId base = this->writeExpression(*swizzle.fBase, out);
   1933     SpvId result = this->nextId();
   1934     size_t count = swizzle.fComponents.size();
   1935     if (count == 1) {
   1936         this->writeInstruction(SpvOpCompositeExtract, this->getType(swizzle.fType), result, base,
   1937                                swizzle.fComponents[0], out);
   1938     } else {
   1939         this->writeOpCode(SpvOpVectorShuffle, 5 + (int32_t) count, out);
   1940         this->writeWord(this->getType(swizzle.fType), out);
   1941         this->writeWord(result, out);
   1942         this->writeWord(base, out);
   1943         this->writeWord(base, out);
   1944         for (int component : swizzle.fComponents) {
   1945             this->writeWord(component, out);
   1946         }
   1947     }
   1948     return result;
   1949 }
   1950 
   1951 SpvId SPIRVCodeGenerator::writeBinaryOperation(const Type& resultType,
   1952                                                const Type& operandType, SpvId lhs,
   1953                                                SpvId rhs, SpvOp_ ifFloat, SpvOp_ ifInt,
   1954                                                SpvOp_ ifUInt, SpvOp_ ifBool, SkWStream& out) {
   1955     SpvId result = this->nextId();
   1956     if (is_float(fContext, operandType)) {
   1957         this->writeInstruction(ifFloat, this->getType(resultType), result, lhs, rhs, out);
   1958     } else if (is_signed(fContext, operandType)) {
   1959         this->writeInstruction(ifInt, this->getType(resultType), result, lhs, rhs, out);
   1960     } else if (is_unsigned(fContext, operandType)) {
   1961         this->writeInstruction(ifUInt, this->getType(resultType), result, lhs, rhs, out);
   1962     } else if (operandType == *fContext.fBool_Type) {
   1963         this->writeInstruction(ifBool, this->getType(resultType), result, lhs, rhs, out);
   1964     } else {
   1965         ABORT("invalid operandType: %s", operandType.description().c_str());
   1966     }
   1967     return result;
   1968 }
   1969 
   1970 bool is_assignment(Token::Kind op) {
   1971     switch (op) {
   1972         case Token::EQ:           // fall through
   1973         case Token::PLUSEQ:       // fall through
   1974         case Token::MINUSEQ:      // fall through
   1975         case Token::STAREQ:       // fall through
   1976         case Token::SLASHEQ:      // fall through
   1977         case Token::PERCENTEQ:    // fall through
   1978         case Token::SHLEQ:        // fall through
   1979         case Token::SHREQ:        // fall through
   1980         case Token::BITWISEOREQ:  // fall through
   1981         case Token::BITWISEXOREQ: // fall through
   1982         case Token::BITWISEANDEQ: // fall through
   1983         case Token::LOGICALOREQ:  // fall through
   1984         case Token::LOGICALXOREQ: // fall through
   1985         case Token::LOGICALANDEQ:
   1986             return true;
   1987         default:
   1988             return false;
   1989     }
   1990 }
   1991 
   1992 SpvId SPIRVCodeGenerator::foldToBool(SpvId id, const Type& operandType, SkWStream& out) {
   1993     if (operandType.kind() == Type::kVector_Kind) {
   1994         SpvId result = this->nextId();
   1995         this->writeInstruction(SpvOpAll, this->getType(*fContext.fBool_Type), result, id, out);
   1996         return result;
   1997     }
   1998     return id;
   1999 }
   2000 
   2001 SpvId SPIRVCodeGenerator::writeBinaryExpression(const BinaryExpression& b, SkWStream& out) {
   2002     // handle cases where we don't necessarily evaluate both LHS and RHS
   2003     switch (b.fOperator) {
   2004         case Token::EQ: {
   2005             SpvId rhs = this->writeExpression(*b.fRight, out);
   2006             this->getLValue(*b.fLeft, out)->store(rhs, out);
   2007             return rhs;
   2008         }
   2009         case Token::LOGICALAND:
   2010             return this->writeLogicalAnd(b, out);
   2011         case Token::LOGICALOR:
   2012             return this->writeLogicalOr(b, out);
   2013         default:
   2014             break;
   2015     }
   2016 
   2017     // "normal" operators
   2018     const Type& resultType = b.fType;
   2019     std::unique_ptr<LValue> lvalue;
   2020     SpvId lhs;
   2021     if (is_assignment(b.fOperator)) {
   2022         lvalue = this->getLValue(*b.fLeft, out);
   2023         lhs = lvalue->load(out);
   2024     } else {
   2025         lvalue = nullptr;
   2026         lhs = this->writeExpression(*b.fLeft, out);
   2027     }
   2028     SpvId rhs = this->writeExpression(*b.fRight, out);
   2029     // component type we are operating on: float, int, uint
   2030     const Type* operandType;
   2031     // IR allows mismatched types in expressions (e.g. vec2 * float), but they need special handling
   2032     // in SPIR-V
   2033     if (b.fLeft->fType != b.fRight->fType) {
   2034         if (b.fLeft->fType.kind() == Type::kVector_Kind &&
   2035             b.fRight->fType.isNumber()) {
   2036             // promote number to vector
   2037             SpvId vec = this->nextId();
   2038             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
   2039             this->writeWord(this->getType(resultType), out);
   2040             this->writeWord(vec, out);
   2041             for (int i = 0; i < resultType.columns(); i++) {
   2042                 this->writeWord(rhs, out);
   2043             }
   2044             rhs = vec;
   2045             operandType = &b.fRight->fType;
   2046         } else if (b.fRight->fType.kind() == Type::kVector_Kind &&
   2047                    b.fLeft->fType.isNumber()) {
   2048             // promote number to vector
   2049             SpvId vec = this->nextId();
   2050             this->writeOpCode(SpvOpCompositeConstruct, 3 + b.fType.columns(), out);
   2051             this->writeWord(this->getType(resultType), out);
   2052             this->writeWord(vec, out);
   2053             for (int i = 0; i < resultType.columns(); i++) {
   2054                 this->writeWord(lhs, out);
   2055             }
   2056             lhs = vec;
   2057             ASSERT(!lvalue);
   2058             operandType = &b.fLeft->fType;
   2059         } else if (b.fLeft->fType.kind() == Type::kMatrix_Kind) {
   2060             SpvOp_ op;
   2061             if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
   2062                 op = SpvOpMatrixTimesMatrix;
   2063             } else if (b.fRight->fType.kind() == Type::kVector_Kind) {
   2064                 op = SpvOpMatrixTimesVector;
   2065             } else {
   2066                 ASSERT(b.fRight->fType.kind() == Type::kScalar_Kind);
   2067                 op = SpvOpMatrixTimesScalar;
   2068             }
   2069             SpvId result = this->nextId();
   2070             this->writeInstruction(op, this->getType(b.fType), result, lhs, rhs, out);
   2071             if (b.fOperator == Token::STAREQ) {
   2072                 lvalue->store(result, out);
   2073             } else {
   2074                 ASSERT(b.fOperator == Token::STAR);
   2075             }
   2076             return result;
   2077         } else if (b.fRight->fType.kind() == Type::kMatrix_Kind) {
   2078             SpvId result = this->nextId();
   2079             if (b.fLeft->fType.kind() == Type::kVector_Kind) {
   2080                 this->writeInstruction(SpvOpVectorTimesMatrix, this->getType(b.fType), result,
   2081                                        lhs, rhs, out);
   2082             } else {
   2083                 ASSERT(b.fLeft->fType.kind() == Type::kScalar_Kind);
   2084                 this->writeInstruction(SpvOpMatrixTimesScalar, this->getType(b.fType), result, rhs,
   2085                                        lhs, out);
   2086             }
   2087             if (b.fOperator == Token::STAREQ) {
   2088                 lvalue->store(result, out);
   2089             } else {
   2090                 ASSERT(b.fOperator == Token::STAR);
   2091             }
   2092             return result;
   2093         } else {
   2094             ABORT("unsupported binary expression: %s", b.description().c_str());
   2095         }
   2096     } else {
   2097         operandType = &b.fLeft->fType;
   2098         ASSERT(*operandType == b.fRight->fType);
   2099     }
   2100     switch (b.fOperator) {
   2101         case Token::EQEQ: {
   2102             ASSERT(resultType == *fContext.fBool_Type);
   2103             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
   2104                                                                SpvOpFOrdEqual, SpvOpIEqual,
   2105                                                                SpvOpIEqual, SpvOpLogicalEqual, out),
   2106                                     *operandType, out);
   2107         }
   2108         case Token::NEQ:
   2109             ASSERT(resultType == *fContext.fBool_Type);
   2110             return this->foldToBool(this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
   2111                                                                SpvOpFOrdNotEqual, SpvOpINotEqual,
   2112                                                                SpvOpINotEqual, SpvOpLogicalNotEqual,
   2113                                                                out),
   2114                                     *operandType, out);
   2115         case Token::GT:
   2116             ASSERT(resultType == *fContext.fBool_Type);
   2117             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
   2118                                               SpvOpFOrdGreaterThan, SpvOpSGreaterThan,
   2119                                               SpvOpUGreaterThan, SpvOpUndef, out);
   2120         case Token::LT:
   2121             ASSERT(resultType == *fContext.fBool_Type);
   2122             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFOrdLessThan,
   2123                                               SpvOpSLessThan, SpvOpULessThan, SpvOpUndef, out);
   2124         case Token::GTEQ:
   2125             ASSERT(resultType == *fContext.fBool_Type);
   2126             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
   2127                                               SpvOpFOrdGreaterThanEqual, SpvOpSGreaterThanEqual,
   2128                                               SpvOpUGreaterThanEqual, SpvOpUndef, out);
   2129         case Token::LTEQ:
   2130             ASSERT(resultType == *fContext.fBool_Type);
   2131             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs,
   2132                                               SpvOpFOrdLessThanEqual, SpvOpSLessThanEqual,
   2133                                               SpvOpULessThanEqual, SpvOpUndef, out);
   2134         case Token::PLUS:
   2135             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
   2136                                               SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
   2137         case Token::MINUS:
   2138             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
   2139                                               SpvOpISub, SpvOpISub, SpvOpUndef, out);
   2140         case Token::STAR:
   2141             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
   2142                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
   2143                 // matrix multiply
   2144                 SpvId result = this->nextId();
   2145                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
   2146                                        lhs, rhs, out);
   2147                 return result;
   2148             }
   2149             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
   2150                                               SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
   2151         case Token::SLASH:
   2152             return this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
   2153                                               SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
   2154         case Token::PLUSEQ: {
   2155             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFAdd,
   2156                                                       SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
   2157             ASSERT(lvalue);
   2158             lvalue->store(result, out);
   2159             return result;
   2160         }
   2161         case Token::MINUSEQ: {
   2162             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFSub,
   2163                                                       SpvOpISub, SpvOpISub, SpvOpUndef, out);
   2164             ASSERT(lvalue);
   2165             lvalue->store(result, out);
   2166             return result;
   2167         }
   2168         case Token::STAREQ: {
   2169             if (b.fLeft->fType.kind() == Type::kMatrix_Kind &&
   2170                 b.fRight->fType.kind() == Type::kMatrix_Kind) {
   2171                 // matrix multiply
   2172                 SpvId result = this->nextId();
   2173                 this->writeInstruction(SpvOpMatrixTimesMatrix, this->getType(resultType), result,
   2174                                        lhs, rhs, out);
   2175                 ASSERT(lvalue);
   2176                 lvalue->store(result, out);
   2177                 return result;
   2178             }
   2179             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFMul,
   2180                                                       SpvOpIMul, SpvOpIMul, SpvOpUndef, out);
   2181             ASSERT(lvalue);
   2182             lvalue->store(result, out);
   2183             return result;
   2184         }
   2185         case Token::SLASHEQ: {
   2186             SpvId result = this->writeBinaryOperation(resultType, *operandType, lhs, rhs, SpvOpFDiv,
   2187                                                       SpvOpSDiv, SpvOpUDiv, SpvOpUndef, out);
   2188             ASSERT(lvalue);
   2189             lvalue->store(result, out);
   2190             return result;
   2191         }
   2192         default:
   2193             // FIXME: missing support for some operators (bitwise, &&=, ||=, shift...)
   2194             ABORT("unsupported binary expression: %s", b.description().c_str());
   2195     }
   2196 }
   2197 
   2198 SpvId SPIRVCodeGenerator::writeLogicalAnd(const BinaryExpression& a, SkWStream& out) {
   2199     ASSERT(a.fOperator == Token::LOGICALAND);
   2200     BoolLiteral falseLiteral(fContext, Position(), false);
   2201     SpvId falseConstant = this->writeBoolLiteral(falseLiteral);
   2202     SpvId lhs = this->writeExpression(*a.fLeft, out);
   2203     SpvId rhsLabel = this->nextId();
   2204     SpvId end = this->nextId();
   2205     SpvId lhsBlock = fCurrentBlock;
   2206     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
   2207     this->writeInstruction(SpvOpBranchConditional, lhs, rhsLabel, end, out);
   2208     this->writeLabel(rhsLabel, out);
   2209     SpvId rhs = this->writeExpression(*a.fRight, out);
   2210     SpvId rhsBlock = fCurrentBlock;
   2211     this->writeInstruction(SpvOpBranch, end, out);
   2212     this->writeLabel(end, out);
   2213     SpvId result = this->nextId();
   2214     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, falseConstant,
   2215                            lhsBlock, rhs, rhsBlock, out);
   2216     return result;
   2217 }
   2218 
   2219 SpvId SPIRVCodeGenerator::writeLogicalOr(const BinaryExpression& o, SkWStream& out) {
   2220     ASSERT(o.fOperator == Token::LOGICALOR);
   2221     BoolLiteral trueLiteral(fContext, Position(), true);
   2222     SpvId trueConstant = this->writeBoolLiteral(trueLiteral);
   2223     SpvId lhs = this->writeExpression(*o.fLeft, out);
   2224     SpvId rhsLabel = this->nextId();
   2225     SpvId end = this->nextId();
   2226     SpvId lhsBlock = fCurrentBlock;
   2227     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
   2228     this->writeInstruction(SpvOpBranchConditional, lhs, end, rhsLabel, out);
   2229     this->writeLabel(rhsLabel, out);
   2230     SpvId rhs = this->writeExpression(*o.fRight, out);
   2231     SpvId rhsBlock = fCurrentBlock;
   2232     this->writeInstruction(SpvOpBranch, end, out);
   2233     this->writeLabel(end, out);
   2234     SpvId result = this->nextId();
   2235     this->writeInstruction(SpvOpPhi, this->getType(*fContext.fBool_Type), result, trueConstant,
   2236                            lhsBlock, rhs, rhsBlock, out);
   2237     return result;
   2238 }
   2239 
   2240 SpvId SPIRVCodeGenerator::writeTernaryExpression(const TernaryExpression& t, SkWStream& out) {
   2241     SpvId test = this->writeExpression(*t.fTest, out);
   2242     if (t.fIfTrue->isConstant() && t.fIfFalse->isConstant()) {
   2243         // both true and false are constants, can just use OpSelect
   2244         SpvId result = this->nextId();
   2245         SpvId trueId = this->writeExpression(*t.fIfTrue, out);
   2246         SpvId falseId = this->writeExpression(*t.fIfFalse, out);
   2247         this->writeInstruction(SpvOpSelect, this->getType(t.fType), result, test, trueId, falseId,
   2248                                out);
   2249         return result;
   2250     }
   2251     // was originally using OpPhi to choose the result, but for some reason that is crashing on
   2252     // Adreno. Switched to storing the result in a temp variable as glslang does.
   2253     SpvId var = this->nextId();
   2254     this->writeInstruction(SpvOpVariable, this->getPointerType(t.fType, SpvStorageClassFunction),
   2255                            var, SpvStorageClassFunction, fVariableBuffer);
   2256     SpvId trueLabel = this->nextId();
   2257     SpvId falseLabel = this->nextId();
   2258     SpvId end = this->nextId();
   2259     this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
   2260     this->writeInstruction(SpvOpBranchConditional, test, trueLabel, falseLabel, out);
   2261     this->writeLabel(trueLabel, out);
   2262     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfTrue, out), out);
   2263     this->writeInstruction(SpvOpBranch, end, out);
   2264     this->writeLabel(falseLabel, out);
   2265     this->writeInstruction(SpvOpStore, var, this->writeExpression(*t.fIfFalse, out), out);
   2266     this->writeInstruction(SpvOpBranch, end, out);
   2267     this->writeLabel(end, out);
   2268     SpvId result = this->nextId();
   2269     this->writeInstruction(SpvOpLoad, this->getType(t.fType), result, var, out);
   2270     return result;
   2271 }
   2272 
   2273 std::unique_ptr<Expression> create_literal_1(const Context& context, const Type& type) {
   2274     if (type == *context.fInt_Type) {
   2275         return std::unique_ptr<Expression>(new IntLiteral(context, Position(), 1));
   2276     }
   2277     else if (type == *context.fFloat_Type) {
   2278         return std::unique_ptr<Expression>(new FloatLiteral(context, Position(), 1.0));
   2279     } else {
   2280         ABORT("math is unsupported on type '%s'")
   2281     }
   2282 }
   2283 
   2284 SpvId SPIRVCodeGenerator::writePrefixExpression(const PrefixExpression& p, SkWStream& out) {
   2285     if (p.fOperator == Token::MINUS) {
   2286         SpvId result = this->nextId();
   2287         SpvId typeId = this->getType(p.fType);
   2288         SpvId expr = this->writeExpression(*p.fOperand, out);
   2289         if (is_float(fContext, p.fType)) {
   2290             this->writeInstruction(SpvOpFNegate, typeId, result, expr, out);
   2291         } else if (is_signed(fContext, p.fType)) {
   2292             this->writeInstruction(SpvOpSNegate, typeId, result, expr, out);
   2293         } else {
   2294             ABORT("unsupported prefix expression %s", p.description().c_str());
   2295         };
   2296         return result;
   2297     }
   2298     switch (p.fOperator) {
   2299         case Token::PLUS:
   2300             return this->writeExpression(*p.fOperand, out);
   2301         case Token::PLUSPLUS: {
   2302             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
   2303             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
   2304             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
   2305                                                       SpvOpFAdd, SpvOpIAdd, SpvOpIAdd, SpvOpUndef,
   2306                                                       out);
   2307             lv->store(result, out);
   2308             return result;
   2309         }
   2310         case Token::MINUSMINUS: {
   2311             std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
   2312             SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
   2313             SpvId result = this->writeBinaryOperation(p.fType, p.fType, lv->load(out), one,
   2314                                                       SpvOpFSub, SpvOpISub, SpvOpISub, SpvOpUndef,
   2315                                                       out);
   2316             lv->store(result, out);
   2317             return result;
   2318         }
   2319         case Token::LOGICALNOT: {
   2320             ASSERT(p.fOperand->fType == *fContext.fBool_Type);
   2321             SpvId result = this->nextId();
   2322             this->writeInstruction(SpvOpLogicalNot, this->getType(p.fOperand->fType), result,
   2323                                    this->writeExpression(*p.fOperand, out), out);
   2324             return result;
   2325         }
   2326         case Token::BITWISENOT: {
   2327             SpvId result = this->nextId();
   2328             this->writeInstruction(SpvOpNot, this->getType(p.fOperand->fType), result,
   2329                                    this->writeExpression(*p.fOperand, out), out);
   2330             return result;
   2331         }
   2332         default:
   2333             ABORT("unsupported prefix expression: %s", p.description().c_str());
   2334     }
   2335 }
   2336 
   2337 SpvId SPIRVCodeGenerator::writePostfixExpression(const PostfixExpression& p, SkWStream& out) {
   2338     std::unique_ptr<LValue> lv = this->getLValue(*p.fOperand, out);
   2339     SpvId result = lv->load(out);
   2340     SpvId one = this->writeExpression(*create_literal_1(fContext, p.fType), out);
   2341     switch (p.fOperator) {
   2342         case Token::PLUSPLUS: {
   2343             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFAdd,
   2344                                                     SpvOpIAdd, SpvOpIAdd, SpvOpUndef, out);
   2345             lv->store(temp, out);
   2346             return result;
   2347         }
   2348         case Token::MINUSMINUS: {
   2349             SpvId temp = this->writeBinaryOperation(p.fType, p.fType, result, one, SpvOpFSub,
   2350                                                     SpvOpISub, SpvOpISub, SpvOpUndef, out);
   2351             lv->store(temp, out);
   2352             return result;
   2353         }
   2354         default:
   2355             ABORT("unsupported postfix expression %s", p.description().c_str());
   2356     }
   2357 }
   2358 
   2359 SpvId SPIRVCodeGenerator::writeBoolLiteral(const BoolLiteral& b) {
   2360     if (b.fValue) {
   2361         if (fBoolTrue == 0) {
   2362             fBoolTrue = this->nextId();
   2363             this->writeInstruction(SpvOpConstantTrue, this->getType(b.fType), fBoolTrue,
   2364                                    fConstantBuffer);
   2365         }
   2366         return fBoolTrue;
   2367     } else {
   2368         if (fBoolFalse == 0) {
   2369             fBoolFalse = this->nextId();
   2370             this->writeInstruction(SpvOpConstantFalse, this->getType(b.fType), fBoolFalse,
   2371                                    fConstantBuffer);
   2372         }
   2373         return fBoolFalse;
   2374     }
   2375 }
   2376 
   2377 SpvId SPIRVCodeGenerator::writeIntLiteral(const IntLiteral& i) {
   2378     if (i.fType == *fContext.fInt_Type) {
   2379         auto entry = fIntConstants.find(i.fValue);
   2380         if (entry == fIntConstants.end()) {
   2381             SpvId result = this->nextId();
   2382             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
   2383                                    fConstantBuffer);
   2384             fIntConstants[i.fValue] = result;
   2385             return result;
   2386         }
   2387         return entry->second;
   2388     } else {
   2389         ASSERT(i.fType == *fContext.fUInt_Type);
   2390         auto entry = fUIntConstants.find(i.fValue);
   2391         if (entry == fUIntConstants.end()) {
   2392             SpvId result = this->nextId();
   2393             this->writeInstruction(SpvOpConstant, this->getType(i.fType), result, (SpvId) i.fValue,
   2394                                    fConstantBuffer);
   2395             fUIntConstants[i.fValue] = result;
   2396             return result;
   2397         }
   2398         return entry->second;
   2399     }
   2400 }
   2401 
   2402 SpvId SPIRVCodeGenerator::writeFloatLiteral(const FloatLiteral& f) {
   2403     if (f.fType == *fContext.fFloat_Type) {
   2404         float value = (float) f.fValue;
   2405         auto entry = fFloatConstants.find(value);
   2406         if (entry == fFloatConstants.end()) {
   2407             SpvId result = this->nextId();
   2408             uint32_t bits;
   2409             ASSERT(sizeof(bits) == sizeof(value));
   2410             memcpy(&bits, &value, sizeof(bits));
   2411             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result, bits,
   2412                                    fConstantBuffer);
   2413             fFloatConstants[value] = result;
   2414             return result;
   2415         }
   2416         return entry->second;
   2417     } else {
   2418         ASSERT(f.fType == *fContext.fDouble_Type);
   2419         auto entry = fDoubleConstants.find(f.fValue);
   2420         if (entry == fDoubleConstants.end()) {
   2421             SpvId result = this->nextId();
   2422             uint64_t bits;
   2423             ASSERT(sizeof(bits) == sizeof(f.fValue));
   2424             memcpy(&bits, &f.fValue, sizeof(bits));
   2425             this->writeInstruction(SpvOpConstant, this->getType(f.fType), result,
   2426                                    bits & 0xffffffff, bits >> 32, fConstantBuffer);
   2427             fDoubleConstants[f.fValue] = result;
   2428             return result;
   2429         }
   2430         return entry->second;
   2431     }
   2432 }
   2433 
   2434 SpvId SPIRVCodeGenerator::writeFunctionStart(const FunctionDeclaration& f, SkWStream& out) {
   2435     SpvId result = fFunctionMap[&f];
   2436     this->writeInstruction(SpvOpFunction, this->getType(f.fReturnType), result,
   2437                            SpvFunctionControlMaskNone, this->getFunctionType(f), out);
   2438     this->writeInstruction(SpvOpName, result, f.fName.c_str(), fNameBuffer);
   2439     for (size_t i = 0; i < f.fParameters.size(); i++) {
   2440         SpvId id = this->nextId();
   2441         fVariableMap[f.fParameters[i]] = id;
   2442         SpvId type;
   2443         type = this->getPointerType(f.fParameters[i]->fType, SpvStorageClassFunction);
   2444         this->writeInstruction(SpvOpFunctionParameter, type, id, out);
   2445     }
   2446     return result;
   2447 }
   2448 
   2449 SpvId SPIRVCodeGenerator::writeFunction(const FunctionDefinition& f, SkWStream& out) {
   2450     SpvId result = this->writeFunctionStart(f.fDeclaration, out);
   2451     this->writeLabel(this->nextId(), out);
   2452     if (f.fDeclaration.fName == "main") {
   2453         write_data(*fGlobalInitializersBuffer.detachAsData(), out);
   2454     }
   2455     SkDynamicMemoryWStream bodyBuffer;
   2456     this->writeBlock(*f.fBody, bodyBuffer);
   2457     write_data(*fVariableBuffer.detachAsData(), out);
   2458     write_data(*bodyBuffer.detachAsData(), out);
   2459     if (fCurrentBlock) {
   2460         this->writeInstruction(SpvOpReturn, out);
   2461     }
   2462     this->writeInstruction(SpvOpFunctionEnd, out);
   2463     return result;
   2464 }
   2465 
   2466 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target) {
   2467     if (layout.fLocation >= 0) {
   2468         this->writeInstruction(SpvOpDecorate, target, SpvDecorationLocation, layout.fLocation,
   2469                                fDecorationBuffer);
   2470     }
   2471     if (layout.fBinding >= 0) {
   2472         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBinding, layout.fBinding,
   2473                                fDecorationBuffer);
   2474     }
   2475     if (layout.fIndex >= 0) {
   2476         this->writeInstruction(SpvOpDecorate, target, SpvDecorationIndex, layout.fIndex,
   2477                                fDecorationBuffer);
   2478     }
   2479     if (layout.fSet >= 0) {
   2480         this->writeInstruction(SpvOpDecorate, target, SpvDecorationDescriptorSet, layout.fSet,
   2481                                fDecorationBuffer);
   2482     }
   2483     if (layout.fInputAttachmentIndex >= 0) {
   2484         this->writeInstruction(SpvOpDecorate, target, SpvDecorationInputAttachmentIndex,
   2485                                layout.fInputAttachmentIndex, fDecorationBuffer);
   2486     }
   2487     if (layout.fBuiltin >= 0 && layout.fBuiltin != SK_FRAGCOLOR_BUILTIN) {
   2488         this->writeInstruction(SpvOpDecorate, target, SpvDecorationBuiltIn, layout.fBuiltin,
   2489                                fDecorationBuffer);
   2490     }
   2491 }
   2492 
   2493 void SPIRVCodeGenerator::writeLayout(const Layout& layout, SpvId target, int member) {
   2494     if (layout.fLocation >= 0) {
   2495         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationLocation,
   2496                                layout.fLocation, fDecorationBuffer);
   2497     }
   2498     if (layout.fBinding >= 0) {
   2499         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBinding,
   2500                                layout.fBinding, fDecorationBuffer);
   2501     }
   2502     if (layout.fIndex >= 0) {
   2503         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationIndex,
   2504                                layout.fIndex, fDecorationBuffer);
   2505     }
   2506     if (layout.fSet >= 0) {
   2507         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationDescriptorSet,
   2508                                layout.fSet, fDecorationBuffer);
   2509     }
   2510     if (layout.fInputAttachmentIndex >= 0) {
   2511         this->writeInstruction(SpvOpDecorate, target, member, SpvDecorationInputAttachmentIndex,
   2512                                layout.fInputAttachmentIndex, fDecorationBuffer);
   2513     }
   2514     if (layout.fBuiltin >= 0) {
   2515         this->writeInstruction(SpvOpMemberDecorate, target, member, SpvDecorationBuiltIn,
   2516                                layout.fBuiltin, fDecorationBuffer);
   2517     }
   2518 }
   2519 
   2520 SpvId SPIRVCodeGenerator::writeInterfaceBlock(const InterfaceBlock& intf) {
   2521     MemoryLayout layout = intf.fVariable.fModifiers.fLayout.fPushConstant ?
   2522                           MemoryLayout(MemoryLayout::k430_Standard) :
   2523                           fDefaultLayout;
   2524     SpvId result = this->nextId();
   2525     const Type* type = &intf.fVariable.fType;
   2526     if (fProgram.fInputs.fRTHeight) {
   2527         ASSERT(fRTHeightStructId == (SpvId) -1);
   2528         ASSERT(fRTHeightFieldIndex == (SpvId) -1);
   2529         std::vector<Type::Field> fields = type->fields();
   2530         fRTHeightStructId = result;
   2531         fRTHeightFieldIndex = fields.size();
   2532         fields.emplace_back(Modifiers(), SkString(SKSL_RTHEIGHT_NAME), fContext.fFloat_Type.get());
   2533         type = new Type(type->fPosition, type->name(), fields);
   2534     }
   2535     SpvId typeId = this->getType(*type, layout);
   2536     this->writeInstruction(SpvOpDecorate, typeId, SpvDecorationBlock, fDecorationBuffer);
   2537     SpvStorageClass_ storageClass = get_storage_class(intf.fVariable.fModifiers);
   2538     SpvId ptrType = this->nextId();
   2539     this->writeInstruction(SpvOpTypePointer, ptrType, storageClass, typeId, fConstantBuffer);
   2540     this->writeInstruction(SpvOpVariable, ptrType, result, storageClass, fConstantBuffer);
   2541     this->writeLayout(intf.fVariable.fModifiers.fLayout, result);
   2542     fVariableMap[&intf.fVariable] = result;
   2543     if (fProgram.fInputs.fRTHeight) {
   2544         delete type;
   2545     }
   2546     return result;
   2547 }
   2548 
   2549 #define BUILTIN_IGNORE 9999
   2550 void SPIRVCodeGenerator::writeGlobalVars(Program::Kind kind, const VarDeclarations& decl,
   2551                                          SkWStream& out) {
   2552     for (size_t i = 0; i < decl.fVars.size(); i++) {
   2553         const VarDeclaration& varDecl = decl.fVars[i];
   2554         const Variable* var = varDecl.fVar;
   2555         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
   2556         // in the OpenGL backend.
   2557         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
   2558                                            Modifiers::kWriteOnly_Flag |
   2559                                            Modifiers::kCoherent_Flag |
   2560                                            Modifiers::kVolatile_Flag |
   2561                                            Modifiers::kRestrict_Flag)));
   2562         if (var->fModifiers.fLayout.fBuiltin == BUILTIN_IGNORE) {
   2563             continue;
   2564         }
   2565         if (var->fModifiers.fLayout.fBuiltin == SK_FRAGCOLOR_BUILTIN &&
   2566             kind != Program::kFragment_Kind) {
   2567             continue;
   2568         }
   2569         if (!var->fReadCount && !var->fWriteCount &&
   2570                 !(var->fModifiers.fFlags & (Modifiers::kIn_Flag |
   2571                                             Modifiers::kOut_Flag |
   2572                                             Modifiers::kUniform_Flag))) {
   2573             // variable is dead and not an input / output var (the Vulkan debug layers complain if
   2574             // we elide an interface var, even if it's dead)
   2575             continue;
   2576         }
   2577         SpvStorageClass_ storageClass;
   2578         if (var->fModifiers.fFlags & Modifiers::kIn_Flag) {
   2579             storageClass = SpvStorageClassInput;
   2580         } else if (var->fModifiers.fFlags & Modifiers::kOut_Flag) {
   2581             storageClass = SpvStorageClassOutput;
   2582         } else if (var->fModifiers.fFlags & Modifiers::kUniform_Flag) {
   2583             if (var->fType.kind() == Type::kSampler_Kind) {
   2584                 storageClass = SpvStorageClassUniformConstant;
   2585             } else {
   2586                 storageClass = SpvStorageClassUniform;
   2587             }
   2588         } else {
   2589             storageClass = SpvStorageClassPrivate;
   2590         }
   2591         SpvId id = this->nextId();
   2592         fVariableMap[var] = id;
   2593         SpvId type = this->getPointerType(var->fType, storageClass);
   2594         this->writeInstruction(SpvOpVariable, type, id, storageClass, fConstantBuffer);
   2595         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
   2596         if (var->fType.kind() == Type::kMatrix_Kind) {
   2597             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationColMajor,
   2598                                    fDecorationBuffer);
   2599             this->writeInstruction(SpvOpMemberDecorate, id, (SpvId) i, SpvDecorationMatrixStride,
   2600                                    (SpvId) fDefaultLayout.stride(var->fType), fDecorationBuffer);
   2601         }
   2602         if (varDecl.fValue) {
   2603             ASSERT(!fCurrentBlock);
   2604             fCurrentBlock = -1;
   2605             SpvId value = this->writeExpression(*varDecl.fValue, fGlobalInitializersBuffer);
   2606             this->writeInstruction(SpvOpStore, id, value, fGlobalInitializersBuffer);
   2607             fCurrentBlock = 0;
   2608         }
   2609         this->writeLayout(var->fModifiers.fLayout, id);
   2610     }
   2611 }
   2612 
   2613 void SPIRVCodeGenerator::writeVarDeclarations(const VarDeclarations& decl, SkWStream& out) {
   2614     for (const auto& varDecl : decl.fVars) {
   2615         const Variable* var = varDecl.fVar;
   2616         // These haven't been implemented in our SPIR-V generator yet and we only currently use them
   2617         // in the OpenGL backend.
   2618         ASSERT(!(var->fModifiers.fFlags & (Modifiers::kReadOnly_Flag |
   2619                                            Modifiers::kWriteOnly_Flag |
   2620                                            Modifiers::kCoherent_Flag |
   2621                                            Modifiers::kVolatile_Flag |
   2622                                            Modifiers::kRestrict_Flag)));
   2623         SpvId id = this->nextId();
   2624         fVariableMap[var] = id;
   2625         SpvId type = this->getPointerType(var->fType, SpvStorageClassFunction);
   2626         this->writeInstruction(SpvOpVariable, type, id, SpvStorageClassFunction, fVariableBuffer);
   2627         this->writeInstruction(SpvOpName, id, var->fName.c_str(), fNameBuffer);
   2628         if (varDecl.fValue) {
   2629             SpvId value = this->writeExpression(*varDecl.fValue, out);
   2630             this->writeInstruction(SpvOpStore, id, value, out);
   2631         }
   2632     }
   2633 }
   2634 
   2635 void SPIRVCodeGenerator::writeStatement(const Statement& s, SkWStream& out) {
   2636     switch (s.fKind) {
   2637         case Statement::kBlock_Kind:
   2638             this->writeBlock((Block&) s, out);
   2639             break;
   2640         case Statement::kExpression_Kind:
   2641             this->writeExpression(*((ExpressionStatement&) s).fExpression, out);
   2642             break;
   2643         case Statement::kReturn_Kind:
   2644             this->writeReturnStatement((ReturnStatement&) s, out);
   2645             break;
   2646         case Statement::kVarDeclarations_Kind:
   2647             this->writeVarDeclarations(*((VarDeclarationsStatement&) s).fDeclaration, out);
   2648             break;
   2649         case Statement::kIf_Kind:
   2650             this->writeIfStatement((IfStatement&) s, out);
   2651             break;
   2652         case Statement::kFor_Kind:
   2653             this->writeForStatement((ForStatement&) s, out);
   2654             break;
   2655         case Statement::kWhile_Kind:
   2656             this->writeWhileStatement((WhileStatement&) s, out);
   2657             break;
   2658         case Statement::kDo_Kind:
   2659             this->writeDoStatement((DoStatement&) s, out);
   2660             break;
   2661         case Statement::kBreak_Kind:
   2662             this->writeInstruction(SpvOpBranch, fBreakTarget.top(), out);
   2663             break;
   2664         case Statement::kContinue_Kind:
   2665             this->writeInstruction(SpvOpBranch, fContinueTarget.top(), out);
   2666             break;
   2667         case Statement::kDiscard_Kind:
   2668             this->writeInstruction(SpvOpKill, out);
   2669             break;
   2670         default:
   2671             ABORT("unsupported statement: %s", s.description().c_str());
   2672     }
   2673 }
   2674 
   2675 void SPIRVCodeGenerator::writeBlock(const Block& b, SkWStream& out) {
   2676     for (size_t i = 0; i < b.fStatements.size(); i++) {
   2677         this->writeStatement(*b.fStatements[i], out);
   2678     }
   2679 }
   2680 
   2681 void SPIRVCodeGenerator::writeIfStatement(const IfStatement& stmt, SkWStream& out) {
   2682     SpvId test = this->writeExpression(*stmt.fTest, out);
   2683     SpvId ifTrue = this->nextId();
   2684     SpvId ifFalse = this->nextId();
   2685     if (stmt.fIfFalse) {
   2686         SpvId end = this->nextId();
   2687         this->writeInstruction(SpvOpSelectionMerge, end, SpvSelectionControlMaskNone, out);
   2688         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
   2689         this->writeLabel(ifTrue, out);
   2690         this->writeStatement(*stmt.fIfTrue, out);
   2691         if (fCurrentBlock) {
   2692             this->writeInstruction(SpvOpBranch, end, out);
   2693         }
   2694         this->writeLabel(ifFalse, out);
   2695         this->writeStatement(*stmt.fIfFalse, out);
   2696         if (fCurrentBlock) {
   2697             this->writeInstruction(SpvOpBranch, end, out);
   2698         }
   2699         this->writeLabel(end, out);
   2700     } else {
   2701         this->writeInstruction(SpvOpSelectionMerge, ifFalse, SpvSelectionControlMaskNone, out);
   2702         this->writeInstruction(SpvOpBranchConditional, test, ifTrue, ifFalse, out);
   2703         this->writeLabel(ifTrue, out);
   2704         this->writeStatement(*stmt.fIfTrue, out);
   2705         if (fCurrentBlock) {
   2706             this->writeInstruction(SpvOpBranch, ifFalse, out);
   2707         }
   2708         this->writeLabel(ifFalse, out);
   2709     }
   2710 }
   2711 
   2712 void SPIRVCodeGenerator::writeForStatement(const ForStatement& f, SkWStream& out) {
   2713     if (f.fInitializer) {
   2714         this->writeStatement(*f.fInitializer, out);
   2715     }
   2716     SpvId header = this->nextId();
   2717     SpvId start = this->nextId();
   2718     SpvId body = this->nextId();
   2719     SpvId next = this->nextId();
   2720     fContinueTarget.push(next);
   2721     SpvId end = this->nextId();
   2722     fBreakTarget.push(end);
   2723     this->writeInstruction(SpvOpBranch, header, out);
   2724     this->writeLabel(header, out);
   2725     this->writeInstruction(SpvOpLoopMerge, end, next, SpvLoopControlMaskNone, out);
   2726     this->writeInstruction(SpvOpBranch, start, out);
   2727     this->writeLabel(start, out);
   2728     if (f.fTest) {
   2729         SpvId test = this->writeExpression(*f.fTest, out);
   2730         this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
   2731     }
   2732     this->writeLabel(body, out);
   2733     this->writeStatement(*f.fStatement, out);
   2734     if (fCurrentBlock) {
   2735         this->writeInstruction(SpvOpBranch, next, out);
   2736     }
   2737     this->writeLabel(next, out);
   2738     if (f.fNext) {
   2739         this->writeExpression(*f.fNext, out);
   2740     }
   2741     this->writeInstruction(SpvOpBranch, header, out);
   2742     this->writeLabel(end, out);
   2743     fBreakTarget.pop();
   2744     fContinueTarget.pop();
   2745 }
   2746 
   2747 void SPIRVCodeGenerator::writeWhileStatement(const WhileStatement& w, SkWStream& out) {
   2748     // We believe the while loop code below will work, but Skia doesn't actually use them and
   2749     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
   2750     // the time being, we just fail with an error due to the lack of testing. If you encounter this
   2751     // message, simply remove the error call below to see whether our while loop support actually
   2752     // works.
   2753     fErrors.error(w.fPosition, "internal error: while loop support has been disabled in SPIR-V, "
   2754                   "see SkSLSPIRVCodeGenerator.cpp for details");
   2755 
   2756     SpvId header = this->nextId();
   2757     SpvId start = this->nextId();
   2758     SpvId body = this->nextId();
   2759     fContinueTarget.push(start);
   2760     SpvId end = this->nextId();
   2761     fBreakTarget.push(end);
   2762     this->writeInstruction(SpvOpBranch, header, out);
   2763     this->writeLabel(header, out);
   2764     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
   2765     this->writeInstruction(SpvOpBranch, start, out);
   2766     this->writeLabel(start, out);
   2767     SpvId test = this->writeExpression(*w.fTest, out);
   2768     this->writeInstruction(SpvOpBranchConditional, test, body, end, out);
   2769     this->writeLabel(body, out);
   2770     this->writeStatement(*w.fStatement, out);
   2771     if (fCurrentBlock) {
   2772         this->writeInstruction(SpvOpBranch, start, out);
   2773     }
   2774     this->writeLabel(end, out);
   2775     fBreakTarget.pop();
   2776     fContinueTarget.pop();
   2777 }
   2778 
   2779 void SPIRVCodeGenerator::writeDoStatement(const DoStatement& d, SkWStream& out) {
   2780     // We believe the do loop code below will work, but Skia doesn't actually use them and
   2781     // adequately testing this code in the absence of Skia exercising it isn't straightforward. For
   2782     // the time being, we just fail with an error due to the lack of testing. If you encounter this
   2783     // message, simply remove the error call below to see whether our do loop support actually
   2784     // works.
   2785     fErrors.error(d.fPosition, "internal error: do loop support has been disabled in SPIR-V, see "
   2786                   "SkSLSPIRVCodeGenerator.cpp for details");
   2787 
   2788     SpvId header = this->nextId();
   2789     SpvId start = this->nextId();
   2790     SpvId next = this->nextId();
   2791     fContinueTarget.push(next);
   2792     SpvId end = this->nextId();
   2793     fBreakTarget.push(end);
   2794     this->writeInstruction(SpvOpBranch, header, out);
   2795     this->writeLabel(header, out);
   2796     this->writeInstruction(SpvOpLoopMerge, end, start, SpvLoopControlMaskNone, out);
   2797     this->writeInstruction(SpvOpBranch, start, out);
   2798     this->writeLabel(start, out);
   2799     this->writeStatement(*d.fStatement, out);
   2800     if (fCurrentBlock) {
   2801         this->writeInstruction(SpvOpBranch, next, out);
   2802     }
   2803     this->writeLabel(next, out);
   2804     SpvId test = this->writeExpression(*d.fTest, out);
   2805     this->writeInstruction(SpvOpBranchConditional, test, start, end, out);
   2806     this->writeLabel(end, out);
   2807     fBreakTarget.pop();
   2808     fContinueTarget.pop();
   2809 }
   2810 
   2811 void SPIRVCodeGenerator::writeReturnStatement(const ReturnStatement& r, SkWStream& out) {
   2812     if (r.fExpression) {
   2813         this->writeInstruction(SpvOpReturnValue, this->writeExpression(*r.fExpression, out),
   2814                                out);
   2815     } else {
   2816         this->writeInstruction(SpvOpReturn, out);
   2817     }
   2818 }
   2819 
   2820 void SPIRVCodeGenerator::writeInstructions(const Program& program, SkWStream& out) {
   2821     fGLSLExtendedInstructions = this->nextId();
   2822     SkDynamicMemoryWStream body;
   2823     std::set<SpvId> interfaceVars;
   2824     // assign IDs to functions
   2825     for (size_t i = 0; i < program.fElements.size(); i++) {
   2826         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
   2827             FunctionDefinition& f = (FunctionDefinition&) *program.fElements[i];
   2828             fFunctionMap[&f.fDeclaration] = this->nextId();
   2829         }
   2830     }
   2831     for (size_t i = 0; i < program.fElements.size(); i++) {
   2832         if (program.fElements[i]->fKind == ProgramElement::kInterfaceBlock_Kind) {
   2833             InterfaceBlock& intf = (InterfaceBlock&) *program.fElements[i];
   2834             SpvId id = this->writeInterfaceBlock(intf);
   2835             if ((intf.fVariable.fModifiers.fFlags & Modifiers::kIn_Flag) ||
   2836                 (intf.fVariable.fModifiers.fFlags & Modifiers::kOut_Flag)) {
   2837                 interfaceVars.insert(id);
   2838             }
   2839         }
   2840     }
   2841     for (size_t i = 0; i < program.fElements.size(); i++) {
   2842         if (program.fElements[i]->fKind == ProgramElement::kVar_Kind) {
   2843             this->writeGlobalVars(program.fKind, ((VarDeclarations&) *program.fElements[i]),
   2844                                   body);
   2845         }
   2846     }
   2847     for (size_t i = 0; i < program.fElements.size(); i++) {
   2848         if (program.fElements[i]->fKind == ProgramElement::kFunction_Kind) {
   2849             this->writeFunction(((FunctionDefinition&) *program.fElements[i]), body);
   2850         }
   2851     }
   2852     const FunctionDeclaration* main = nullptr;
   2853     for (auto entry : fFunctionMap) {
   2854         if (entry.first->fName == "main") {
   2855             main = entry.first;
   2856         }
   2857     }
   2858     ASSERT(main);
   2859     for (auto entry : fVariableMap) {
   2860         const Variable* var = entry.first;
   2861         if (var->fStorage == Variable::kGlobal_Storage &&
   2862                 ((var->fModifiers.fFlags & Modifiers::kIn_Flag) ||
   2863                  (var->fModifiers.fFlags & Modifiers::kOut_Flag))) {
   2864             interfaceVars.insert(entry.second);
   2865         }
   2866     }
   2867     this->writeCapabilities(out);
   2868     this->writeInstruction(SpvOpExtInstImport, fGLSLExtendedInstructions, "GLSL.std.450", out);
   2869     this->writeInstruction(SpvOpMemoryModel, SpvAddressingModelLogical, SpvMemoryModelGLSL450, out);
   2870     this->writeOpCode(SpvOpEntryPoint, (SpvId) (3 + (strlen(main->fName.c_str()) + 4) / 4) +
   2871                       (int32_t) interfaceVars.size(), out);
   2872     switch (program.fKind) {
   2873         case Program::kVertex_Kind:
   2874             this->writeWord(SpvExecutionModelVertex, out);
   2875             break;
   2876         case Program::kFragment_Kind:
   2877             this->writeWord(SpvExecutionModelFragment, out);
   2878             break;
   2879         case Program::kGeometry_Kind:
   2880             this->writeWord(SpvExecutionModelGeometry, out);
   2881             break;
   2882     }
   2883     this->writeWord(fFunctionMap[main], out);
   2884     this->writeString(main->fName.c_str(), out);
   2885     for (int var : interfaceVars) {
   2886         this->writeWord(var, out);
   2887     }
   2888     if (program.fKind == Program::kFragment_Kind) {
   2889         this->writeInstruction(SpvOpExecutionMode,
   2890                                fFunctionMap[main],
   2891                                SpvExecutionModeOriginUpperLeft,
   2892                                out);
   2893     }
   2894     for (size_t i = 0; i < program.fElements.size(); i++) {
   2895         if (program.fElements[i]->fKind == ProgramElement::kExtension_Kind) {
   2896             this->writeInstruction(SpvOpSourceExtension,
   2897                                    ((Extension&) *program.fElements[i]).fName.c_str(),
   2898                                    out);
   2899         }
   2900     }
   2901 
   2902     write_data(*fExtraGlobalsBuffer.detachAsData(), out);
   2903     write_data(*fNameBuffer.detachAsData(), out);
   2904     write_data(*fDecorationBuffer.detachAsData(), out);
   2905     write_data(*fConstantBuffer.detachAsData(), out);
   2906     write_data(*fExternalFunctionsBuffer.detachAsData(), out);
   2907     write_data(*body.detachAsData(), out);
   2908 }
   2909 
   2910 bool SPIRVCodeGenerator::generateCode() {
   2911     ASSERT(!fErrors.errorCount());
   2912     this->writeWord(SpvMagicNumber, *fOut);
   2913     this->writeWord(SpvVersion, *fOut);
   2914     this->writeWord(SKSL_MAGIC, *fOut);
   2915     SkDynamicMemoryWStream buffer;
   2916     this->writeInstructions(fProgram, buffer);
   2917     this->writeWord(fIdCount, *fOut);
   2918     this->writeWord(0, *fOut); // reserved, always zero
   2919     write_data(*buffer.detachAsData(), *fOut);
   2920     return 0 == fErrors.errorCount();
   2921 }
   2922 
   2923 }
   2924