Home | History | Annotate | Download | only in SPIRV
      1 //
      2 //Copyright (C) 2015 LunarG, Inc.
      3 //
      4 //All rights reserved.
      5 //
      6 //Redistribution and use in source and binary forms, with or without
      7 //modification, are permitted provided that the following conditions
      8 //are met:
      9 //
     10 //    Redistributions of source code must retain the above copyright
     11 //    notice, this list of conditions and the following disclaimer.
     12 //
     13 //    Redistributions in binary form must reproduce the above
     14 //    copyright notice, this list of conditions and the following
     15 //    disclaimer in the documentation and/or other materials provided
     16 //    with the distribution.
     17 //
     18 //    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
     19 //    contributors may be used to endorse or promote products derived
     20 //    from this software without specific prior written permission.
     21 //
     22 //THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
     23 //"AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
     24 //LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
     25 //FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
     26 //COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
     27 //INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
     28 //BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
     29 //LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
     30 //CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
     31 //LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
     32 //ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
     33 //POSSIBILITY OF SUCH DAMAGE.
     34 //
     35 
     36 #include "SPVRemapper.h"
     37 #include "doc.h"
     38 
     39 #if !defined (use_cpp11)
     40 // ... not supported before C++11
     41 #else // defined (use_cpp11)
     42 
     43 #include <algorithm>
     44 #include <cassert>
     45 #include "../glslang/Include/Common.h"
     46 
     47 namespace spv {
     48 
    // Default error handler: ignores the message text and terminates the process
    // with exit code 5.  Can be overridden via RegisterErrorHandler.
    spirvbin_t::errorfn_t spirvbin_t::errorHandler = [](const std::string&) { exit(5); };
    // Default log handler: discards every message it is given.
    // Can be overridden via RegisterLogHandler.
    spirvbin_t::logfn_t   spirvbin_t::logHandler   = [](const std::string&) { };
     53 
     54     // This can be overridden to provide other message behavior if needed
     55     void spirvbin_t::msg(int minVerbosity, int indent, const std::string& txt) const
     56     {
     57         if (verbose >= minVerbosity)
     58             logHandler(std::string(indent, ' ') + txt);
     59     }
     60 
     61     // hash opcode, with special handling for OpExtInst
     62     std::uint32_t spirvbin_t::asOpCodeHash(unsigned word)
     63     {
     64         const spv::Op opCode = asOpCode(word);
     65 
     66         std::uint32_t offset = 0;
     67 
     68         switch (opCode) {
     69         case spv::OpExtInst:
     70             offset += asId(word + 4); break;
     71         default:
     72             break;
     73         }
     74 
     75         return opCode * 19 + offset; // 19 = small prime
     76     }
     77 
     78     spirvbin_t::range_t spirvbin_t::literalRange(spv::Op opCode) const
     79     {
     80         static const int maxCount = 1<<30;
     81 
     82         switch (opCode) {
     83         case spv::OpTypeFloat:        // fall through...
     84         case spv::OpTypePointer:      return range_t(2, 3);
     85         case spv::OpTypeInt:          return range_t(2, 4);
     86         // TODO: case spv::OpTypeImage:
     87         // TODO: case spv::OpTypeSampledImage:
     88         case spv::OpTypeSampler:      return range_t(3, 8);
     89         case spv::OpTypeVector:       // fall through
     90         case spv::OpTypeMatrix:       // ...
     91         case spv::OpTypePipe:         return range_t(3, 4);
     92         case spv::OpConstant:         return range_t(3, maxCount);
     93         default:                      return range_t(0, 0);
     94         }
     95     }
     96 
     97     spirvbin_t::range_t spirvbin_t::typeRange(spv::Op opCode) const
     98     {
     99         static const int maxCount = 1<<30;
    100 
    101         if (isConstOp(opCode))
    102             return range_t(1, 2);
    103 
    104         switch (opCode) {
    105         case spv::OpTypeVector:       // fall through
    106         case spv::OpTypeMatrix:       // ...
    107         case spv::OpTypeSampler:      // ...
    108         case spv::OpTypeArray:        // ...
    109         case spv::OpTypeRuntimeArray: // ...
    110         case spv::OpTypePipe:         return range_t(2, 3);
    111         case spv::OpTypeStruct:       // fall through
    112         case spv::OpTypeFunction:     return range_t(2, maxCount);
    113         case spv::OpTypePointer:      return range_t(3, 4);
    114         default:                      return range_t(0, 0);
    115         }
    116     }
    117 
    118     spirvbin_t::range_t spirvbin_t::constRange(spv::Op opCode) const
    119     {
    120         static const int maxCount = 1<<30;
    121 
    122         switch (opCode) {
    123         case spv::OpTypeArray:         // fall through...
    124         case spv::OpTypeRuntimeArray:  return range_t(3, 4);
    125         case spv::OpConstantComposite: return range_t(3, maxCount);
    126         default:                       return range_t(0, 0);
    127         }
    128     }
    129 
    130     // Is this an opcode we should remove when using --strip?
    131     bool spirvbin_t::isStripOp(spv::Op opCode) const
    132     {
    133         switch (opCode) {
    134         case spv::OpSource:
    135         case spv::OpSourceExtension:
    136         case spv::OpName:
    137         case spv::OpMemberName:
    138         case spv::OpLine:           return true;
    139         default:                    return false;
    140         }
    141     }
    142 
    143     bool spirvbin_t::isFlowCtrl(spv::Op opCode) const
    144     {
    145         switch (opCode) {
    146         case spv::OpBranchConditional:
    147         case spv::OpBranch:
    148         case spv::OpSwitch:
    149         case spv::OpLoopMerge:
    150         case spv::OpSelectionMerge:
    151         case spv::OpLabel:
    152         case spv::OpFunction:
    153         case spv::OpFunctionEnd:    return true;
    154         default:                    return false;
    155         }
    156     }
    157 
    158     bool spirvbin_t::isTypeOp(spv::Op opCode) const
    159     {
    160         switch (opCode) {
    161         case spv::OpTypeVoid:
    162         case spv::OpTypeBool:
    163         case spv::OpTypeInt:
    164         case spv::OpTypeFloat:
    165         case spv::OpTypeVector:
    166         case spv::OpTypeMatrix:
    167         case spv::OpTypeImage:
    168         case spv::OpTypeSampler:
    169         case spv::OpTypeArray:
    170         case spv::OpTypeRuntimeArray:
    171         case spv::OpTypeStruct:
    172         case spv::OpTypeOpaque:
    173         case spv::OpTypePointer:
    174         case spv::OpTypeFunction:
    175         case spv::OpTypeEvent:
    176         case spv::OpTypeDeviceEvent:
    177         case spv::OpTypeReserveId:
    178         case spv::OpTypeQueue:
    179         case spv::OpTypeSampledImage:
    180         case spv::OpTypePipe:         return true;
    181         default:                      return false;
    182         }
    183     }
    184 
    185     bool spirvbin_t::isConstOp(spv::Op opCode) const
    186     {
    187         switch (opCode) {
    188         case spv::OpConstantNull:       error("unimplemented constant type");
    189         case spv::OpConstantSampler:    error("unimplemented constant type");
    190 
    191         case spv::OpConstantTrue:
    192         case spv::OpConstantFalse:
    193         case spv::OpConstantComposite:
    194         case spv::OpConstant:         return true;
    195         default:                      return false;
    196         }
    197     }
    198 
    // No-op instruction callback for process(): inspects nothing and returns false,
    // so processInstruction() goes on to visit the instruction's operands.
    const auto inst_fn_nop = [](spv::Op, unsigned) { return false; };
    // No-op ID callback for process(): leaves every ID it is handed untouched.
    const auto op_fn_nop   = [](spv::Id&)          { };
    201 
    // g++ doesn't like these defined in the class proper in an anonymous namespace.
    // Dunno why.  Also MSVC doesn't like the constexpr keyword.  Also dunno why.
    // Defining them externally seems to please both compilers, so, here they are.
    //
    // unmapped / unused: sentinel values stored in the local ID map (see localId()).
    // header_size: number of words in the module header; validate() requires at least
    // this many words and process() starts at this word offset by default.
    const spv::Id spirvbin_t::unmapped    = spv::Id(-10000);
    const spv::Id spirvbin_t::unused      = spv::Id(-10001);
    const int     spirvbin_t::header_size = 5;
    208 
    209     spv::Id spirvbin_t::nextUnusedId(spv::Id id)
    210     {
    211         while (isNewIdMapped(id))  // search for an unused ID
    212             ++id;
    213 
    214         return id;
    215     }
    216 
    217     spv::Id spirvbin_t::localId(spv::Id id, spv::Id newId)
    218     {
    219         assert(id != spv::NoResult && newId != spv::NoResult);
    220 
    221         if (id >= idMapL.size())
    222             idMapL.resize(id+1, unused);
    223 
    224         if (newId != unmapped && newId != unused) {
    225             if (isOldIdUnused(id))
    226                 error(std::string("ID unused in module: ") + std::to_string(id));
    227 
    228             if (!isOldIdUnmapped(id))
    229                 error(std::string("ID already mapped: ") + std::to_string(id) + " -> "
    230                 + std::to_string(localId(id)));
    231 
    232             if (isNewIdMapped(newId))
    233                 error(std::string("ID already used in module: ") + std::to_string(newId));
    234 
    235             msg(4, 4, std::string("map: ") + std::to_string(id) + " -> " + std::to_string(newId));
    236             setMapped(newId);
    237             largestNewId = std::max(largestNewId, newId);
    238         }
    239 
    240         return idMapL[id] = newId;
    241     }
    242 
    243     // Parse a literal string from the SPIR binary and return it as an std::string
    244     // Due to C++11 RValue references, this doesn't copy the result string.
    245     std::string spirvbin_t::literalString(unsigned word) const
    246     {
    247         std::string literal;
    248 
    249         literal.reserve(16);
    250 
    251         const char* bytes = reinterpret_cast<const char*>(spv.data() + word);
    252 
    253         while (bytes && *bytes)
    254             literal += *bytes++;
    255 
    256         return literal;
    257     }
    258 
    259 
    260     void spirvbin_t::applyMap()
    261     {
    262         msg(3, 2, std::string("Applying map: "));
    263 
    264         // Map local IDs through the ID map
    265         process(inst_fn_nop, // ignore instructions
    266             [this](spv::Id& id) {
    267                 id = localId(id);
    268                 assert(id != unused && id != unmapped);
    269             }
    270         );
    271     }
    272 
    273 
    274     // Find free IDs for anything we haven't mapped
    275     void spirvbin_t::mapRemainder()
    276     {
    277         msg(3, 2, std::string("Remapping remainder: "));
    278 
    279         spv::Id     unusedId  = 1;  // can't use 0: that's NoResult
    280         spirword_t  maxBound  = 0;
    281 
    282         for (spv::Id id = 0; id < idMapL.size(); ++id) {
    283             if (isOldIdUnused(id))
    284                 continue;
    285 
    286             // Find a new mapping for any used but unmapped IDs
    287             if (isOldIdUnmapped(id))
    288                 localId(id, unusedId = nextUnusedId(unusedId));
    289 
    290             if (isOldIdUnmapped(id))
    291                 error(std::string("old ID not mapped: ") + std::to_string(id));
    292 
    293             // Track max bound
    294             maxBound = std::max(maxBound, localId(id) + 1);
    295         }
    296 
    297         bound(maxBound); // reset header ID bound to as big as it now needs to be
    298     }
    299 
    300     void spirvbin_t::stripDebug()
    301     {
    302         if ((options & STRIP) == 0)
    303             return;
    304 
    305         // build local Id and name maps
    306         process(
    307             [&](spv::Op opCode, unsigned start) {
    308                 // remember opcodes we want to strip later
    309                 if (isStripOp(opCode))
    310                     stripInst(start);
    311                 return true;
    312             },
    313             op_fn_nop);
    314     }
    315 
    316     void spirvbin_t::buildLocalMaps()
    317     {
    318         msg(2, 2, std::string("build local maps: "));
    319 
    320         mapped.clear();
    321         idMapL.clear();
    322 //      preserve nameMap, so we don't clear that.
    323         fnPos.clear();
    324         fnPosDCE.clear();
    325         fnCalls.clear();
    326         typeConstPos.clear();
    327         typeConstPosR.clear();
    328         entryPoint = spv::NoResult;
    329         largestNewId = 0;
    330 
    331         idMapL.resize(bound(), unused);
    332 
    333         int         fnStart = 0;
    334         spv::Id     fnRes   = spv::NoResult;
    335 
    336         // build local Id and name maps
    337         process(
    338             [&](spv::Op opCode, unsigned start) {
    339                 // remember opcodes we want to strip later
    340                 if ((options & STRIP) && isStripOp(opCode))
    341                     stripInst(start);
    342 
    343                 if (opCode == spv::Op::OpName) {
    344                     const spv::Id    target = asId(start+1);
    345                     const std::string  name = literalString(start+2);
    346                     nameMap[name] = target;
    347 
    348                 } else if (opCode == spv::Op::OpFunctionCall) {
    349                     ++fnCalls[asId(start + 3)];
    350                 } else if (opCode == spv::Op::OpEntryPoint) {
    351                     entryPoint = asId(start + 2);
    352                 } else if (opCode == spv::Op::OpFunction) {
    353                     if (fnStart != 0)
    354                         error("nested function found");
    355                     fnStart = start;
    356                     fnRes   = asId(start + 2);
    357                 } else if (opCode == spv::Op::OpFunctionEnd) {
    358                     assert(fnRes != spv::NoResult);
    359                     if (fnStart == 0)
    360                         error("function end without function start");
    361                     fnPos[fnRes] = range_t(fnStart, start + asWordCount(start));
    362                     fnStart = 0;
    363                 } else if (isConstOp(opCode)) {
    364                     assert(asId(start + 2) != spv::NoResult);
    365                     typeConstPos.insert(start);
    366                     typeConstPosR[asId(start + 2)] = start;
    367                 } else if (isTypeOp(opCode)) {
    368                     assert(asId(start + 1) != spv::NoResult);
    369                     typeConstPos.insert(start);
    370                     typeConstPosR[asId(start + 1)] = start;
    371                 }
    372 
    373                 return false;
    374             },
    375 
    376             [this](spv::Id& id) { localId(id, unmapped); }
    377         );
    378     }
    379 
    380     // Validate the SPIR header
    381     void spirvbin_t::validate() const
    382     {
    383         msg(2, 2, std::string("validating: "));
    384 
    385         if (spv.size() < header_size)
    386             error("file too short: ");
    387 
    388         if (magic() != spv::MagicNumber)
    389             error("bad magic number");
    390 
    391         // field 1 = version
    392         // field 2 = generator magic
    393         // field 3 = result <id> bound
    394 
    395         if (schemaNum() != 0)
    396             error("bad schema, must be 0");
    397     }
    398 
    399 
    // Process the single instruction starting at index `word`.
    //
    // instFn is called first with the opcode and the instruction's start index; if it
    // returns true, the operands are not visited.  Otherwise idFn is called on each
    // operand word that this function classifies as an ID (type, result, and ID-class
    // operands according to spv::InstructionDesc).
    //
    // Returns the index of the first word after this instruction (start + word count).
    int spirvbin_t::processInstruction(unsigned word, instfn_t instFn, idfn_t idFn)
    {
        const auto     instructionStart = word;
        const unsigned wordCount = asWordCount(instructionStart);
        const spv::Op  opCode    = asOpCode(instructionStart);
        // `word` is advanced past the opcode word here; nextInst uses the old value
        const int      nextInst  = word++ + wordCount;

        if (nextInst > int(spv.size()))
            error("spir instruction terminated too early");

        // Base for computing number of operands; will be updated as more is learned
        unsigned numOperands = wordCount - 1;

        // Let the instruction callback look first; true means "skip the operands"
        if (instFn(opCode, instructionStart))
            return nextInst;

        // Read type and result ID from instruction desc table
        if (spv::InstructionDesc[opCode].hasType()) {
            idFn(asId(word++));
            --numOperands;
        }

        if (spv::InstructionDesc[opCode].hasResult()) {
            idFn(asId(word++));
            --numOperands;
        }

        // Extended instructions: currently, assume everything is an ID.
        // TODO: add whatever data we need for exceptions to that
        if (opCode == spv::OpExtInst) {
            word        += 2; // instruction set, and instruction from set
            numOperands -= 2;

            for (unsigned op=0; op < numOperands; ++op)
                idFn(asId(word++)); // ID

            return nextInst;
        }

        // Store IDs from instruction in our map
        // `op` indexes the operand class table; numOperands counts the words left
        for (int op = 0; numOperands > 0; ++op, --numOperands) {
            switch (spv::InstructionDesc[opCode].operands.getClass(op)) {
            case spv::OperandId:
                idFn(asId(word++));
                break;

            case spv::OperandVariableIds:
                // All remaining words are IDs
                for (unsigned i = 0; i < numOperands; ++i)
                    idFn(asId(word++));
                return nextInst;

            case spv::OperandVariableLiterals:
                // Remaining words are literals: nothing to visit.
                // for clarity
                // if (opCode == spv::OpDecorate && asDecoration(word - 1) == spv::DecorationBuiltIn) {
                //     ++word;
                //     --numOperands;
                // }
                // word += numOperands;
                return nextInst;

            case spv::OperandVariableLiteralId:
                // Remaining words come in (literal, ID) pairs.
                // NOTE(review): numOperands is unsigned; an odd remaining count would
                // wrap around on the subtraction below instead of ending the loop.
                while (numOperands > 0) {
                    ++word;             // immediate
                    idFn(asId(word++)); // ID
                    numOperands -= 2;
                }
                return nextInst;

            case spv::OperandLiteralString: {
                // Skip over the words occupied by the string
                // NOTE(review): this unsigned subtraction would wrap if the string
                // occupied more words than remain in the instruction.
                const int stringWordCount = literalStringWords(literalString(word));
                word += stringWordCount;
                numOperands -= (stringWordCount-1); // -1 because for() header post-decrements
                break;
            }

            // Execution mode might have extra literal operands.  Skip them.
            case spv::OperandExecutionMode:
                return nextInst;

            // Single word operands we simply ignore, as they hold no IDs
            case spv::OperandLiteralNumber:
            case spv::OperandSource:
            case spv::OperandExecutionModel:
            case spv::OperandAddressing:
            case spv::OperandMemory:
            case spv::OperandStorage:
            case spv::OperandDimensionality:
            case spv::OperandSamplerAddressingMode:
            case spv::OperandSamplerFilterMode:
            case spv::OperandSamplerImageFormat:
            case spv::OperandImageChannelOrder:
            case spv::OperandImageChannelDataType:
            case spv::OperandImageOperands:
            case spv::OperandFPFastMath:
            case spv::OperandFPRoundingMode:
            case spv::OperandLinkageType:
            case spv::OperandAccessQualifier:
            case spv::OperandFuncParamAttr:
            case spv::OperandDecoration:
            case spv::OperandBuiltIn:
            case spv::OperandSelect:
            case spv::OperandLoop:
            case spv::OperandFunction:
            case spv::OperandMemorySemantics:
            case spv::OperandMemoryAccess:
            case spv::OperandScope:
            case spv::OperandGroupOperation:
            case spv::OperandKernelEnqueueFlags:
            case spv::OperandKernelProfilingInfo:
            case spv::OperandCapability:
                ++word;
                break;

            default:
                assert(0 && "Unhandled Operand Class");
                break;
            }
        }

        return nextInst;
    }
    521 
    522     // Make a pass over all the instructions and process them given appropriate functions
    523     spirvbin_t& spirvbin_t::process(instfn_t instFn, idfn_t idFn, unsigned begin, unsigned end)
    524     {
    525         // For efficiency, reserve name map space.  It can grow if needed.
    526         nameMap.reserve(32);
    527 
    528         // If begin or end == 0, use defaults
    529         begin = (begin == 0 ? header_size          : begin);
    530         end   = (end   == 0 ? unsigned(spv.size()) : end);
    531 
    532         // basic parsing and InstructionDesc table borrowed from SpvDisassemble.cpp...
    533         unsigned nextInst = unsigned(spv.size());
    534 
    535         for (unsigned word = begin; word < end; word = nextInst)
    536             nextInst = processInstruction(word, instFn, idFn);
    537 
    538         return *this;
    539     }
    540 
    541     // Apply global name mapping to a single module
    542     void spirvbin_t::mapNames()
    543     {
    544         static const std::uint32_t softTypeIdLimit = 3011;  // small prime.  TODO: get from options
    545         static const std::uint32_t firstMappedID   = 3019;  // offset into ID space
    546 
    547         for (const auto& name : nameMap) {
    548             std::uint32_t hashval = 1911;
    549             for (const char c : name.first)
    550                 hashval = hashval * 1009 + c;
    551 
    552             if (isOldIdUnmapped(name.second))
    553                 localId(name.second, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
    554         }
    555     }
    556 
    557     // Map fn contents to IDs of similar functions in other modules
    558     void spirvbin_t::mapFnBodies()
    559     {
    560         static const std::uint32_t softTypeIdLimit = 19071;  // small prime.  TODO: get from options
    561         static const std::uint32_t firstMappedID   =  6203;  // offset into ID space
    562 
    563         // Initial approach: go through some high priority opcodes first and assign them
    564         // hash values.
    565 
    566         spv::Id               fnId       = spv::NoResult;
    567         std::vector<unsigned> instPos;
    568         instPos.reserve(unsigned(spv.size()) / 16); // initial estimate; can grow if needed.
    569 
    570         // Build local table of instruction start positions
    571         process(
    572             [&](spv::Op, unsigned start) { instPos.push_back(start); return true; },
    573             op_fn_nop);
    574 
    575         // Window size for context-sensitive canonicalization values
    576         // Empirical best size from a single data set.  TODO: Would be a good tunable.
    577         // We essentially perform a little convolution around each instruction,
    578         // to capture the flavor of nearby code, to hopefully match to similar
    579         // code in other modules.
    580         static const unsigned windowSize = 2;
    581 
    582         for (unsigned entry = 0; entry < unsigned(instPos.size()); ++entry) {
    583             const unsigned start  = instPos[entry];
    584             const spv::Op  opCode = asOpCode(start);
    585 
    586             if (opCode == spv::OpFunction)
    587                 fnId   = asId(start + 2);
    588 
    589             if (opCode == spv::OpFunctionEnd)
    590                 fnId = spv::NoResult;
    591 
    592             if (fnId != spv::NoResult) { // if inside a function
    593                 if (spv::InstructionDesc[opCode].hasResult()) {
    594                     const unsigned word    = start + (spv::InstructionDesc[opCode].hasType() ? 2 : 1);
    595                     const spv::Id  resId   = asId(word);
    596                     std::uint32_t  hashval = fnId * 17; // small prime
    597 
    598                     for (unsigned i = entry-1; i >= entry-windowSize; --i) {
    599                         if (asOpCode(instPos[i]) == spv::OpFunction)
    600                             break;
    601                         hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
    602                     }
    603 
    604                     for (unsigned i = entry; i <= entry + windowSize; ++i) {
    605                         if (asOpCode(instPos[i]) == spv::OpFunctionEnd)
    606                             break;
    607                         hashval = hashval * 30103 + asOpCodeHash(instPos[i]); // 30103 = semiarbitrary prime
    608                     }
    609 
    610                     if (isOldIdUnmapped(resId))
    611                         localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
    612                 }
    613             }
    614         }
    615 
    616         spv::Op          thisOpCode(spv::OpNop);
    617         std::unordered_map<int, int> opCounter;
    618         int              idCounter(0);
    619         fnId = spv::NoResult;
    620 
    621         process(
    622             [&](spv::Op opCode, unsigned start) {
    623                 switch (opCode) {
    624                 case spv::OpFunction:
    625                     // Reset counters at each function
    626                     idCounter = 0;
    627                     opCounter.clear();
    628                     fnId = asId(start + 2);
    629                     break;
    630 
    631                 case spv::OpImageSampleImplicitLod:
    632                 case spv::OpImageSampleExplicitLod:
    633                 case spv::OpImageSampleDrefImplicitLod:
    634                 case spv::OpImageSampleDrefExplicitLod:
    635                 case spv::OpImageSampleProjImplicitLod:
    636                 case spv::OpImageSampleProjExplicitLod:
    637                 case spv::OpImageSampleProjDrefImplicitLod:
    638                 case spv::OpImageSampleProjDrefExplicitLod:
    639                 case spv::OpDot:
    640                 case spv::OpCompositeExtract:
    641                 case spv::OpCompositeInsert:
    642                 case spv::OpVectorShuffle:
    643                 case spv::OpLabel:
    644                 case spv::OpVariable:
    645 
    646                 case spv::OpAccessChain:
    647                 case spv::OpLoad:
    648                 case spv::OpStore:
    649                 case spv::OpCompositeConstruct:
    650                 case spv::OpFunctionCall:
    651                     ++opCounter[opCode];
    652                     idCounter = 0;
    653                     thisOpCode = opCode;
    654                     break;
    655                 default:
    656                     thisOpCode = spv::OpNop;
    657                 }
    658 
    659                 return false;
    660             },
    661 
    662             [&](spv::Id& id) {
    663                 if (thisOpCode != spv::OpNop) {
    664                     ++idCounter;
    665                     const std::uint32_t hashval = opCounter[thisOpCode] * thisOpCode * 50047 + idCounter + fnId * 117;
    666 
    667                     if (isOldIdUnmapped(id))
    668                         localId(id, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
    669                 }
    670             });
    671     }
    672 
    673     // EXPERIMENTAL: forward IO and uniform load/stores into operands
    674     // This produces invalid Schema-0 SPIRV
    675     void spirvbin_t::forwardLoadStores()
    676     {
    677         idset_t fnLocalVars; // set of function local vars
    678         idmap_t idMap;       // Map of load result IDs to what they load
    679 
    680         // EXPERIMENTAL: Forward input and access chain loads into consumptions
    681         process(
    682             [&](spv::Op opCode, unsigned start) {
    683                 // Add inputs and uniforms to the map
    684                 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
    685                     (spv[start+3] == spv::StorageClassUniform ||
    686                     spv[start+3] == spv::StorageClassUniformConstant ||
    687                     spv[start+3] == spv::StorageClassInput))
    688                     fnLocalVars.insert(asId(start+2));
    689 
    690                 if (opCode == spv::OpAccessChain && fnLocalVars.count(asId(start+3)) > 0)
    691                     fnLocalVars.insert(asId(start+2));
    692 
    693                 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
    694                     idMap[asId(start+2)] = asId(start+3);
    695                     stripInst(start);
    696                 }
    697 
    698                 return false;
    699             },
    700 
    701             [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
    702         );
    703 
    704         // EXPERIMENTAL: Implicit output stores
    705         fnLocalVars.clear();
    706         idMap.clear();
    707 
    708         process(
    709             [&](spv::Op opCode, unsigned start) {
    710                 // Add inputs and uniforms to the map
    711                 if ((opCode == spv::OpVariable && asWordCount(start) == 4) &&
    712                     (spv[start+3] == spv::StorageClassOutput))
    713                     fnLocalVars.insert(asId(start+2));
    714 
    715                 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
    716                     idMap[asId(start+2)] = asId(start+1);
    717                     stripInst(start);
    718                 }
    719 
    720                 return false;
    721             },
    722             op_fn_nop);
    723 
    724         process(
    725             inst_fn_nop,
    726             [&](spv::Id& id) { if (idMap.find(id) != idMap.end()) id = idMap[id]; }
    727         );
    728 
    729         strip();          // strip out data we decided to eliminate
    730     }
    731 
    732     // optimize loads and stores
    733     void spirvbin_t::optLoadStore()
    734     {
    735         idset_t    fnLocalVars;  // candidates for removal (only locals)
    736         idmap_t    idMap;        // Map of load result IDs to what they load
    737         blockmap_t blockMap;     // Map of IDs to blocks they first appear in
    738         int        blockNum = 0; // block count, to avoid crossing flow control
    739 
    740         // Find all the function local pointers stored at most once, and not via access chains
    741         process(
    742             [&](spv::Op opCode, unsigned start) {
    743                 const int wordCount = asWordCount(start);
    744 
    745                 // Count blocks, so we can avoid crossing flow control
    746                 if (isFlowCtrl(opCode))
    747                     ++blockNum;
    748 
    749                 // Add local variables to the map
    750                 if ((opCode == spv::OpVariable && spv[start+3] == spv::StorageClassFunction && asWordCount(start) == 4)) {
    751                     fnLocalVars.insert(asId(start+2));
    752                     return true;
    753                 }
    754 
    755                 // Ignore process vars referenced via access chain
    756                 if ((opCode == spv::OpAccessChain || opCode == spv::OpInBoundsAccessChain) && fnLocalVars.count(asId(start+3)) > 0) {
    757                     fnLocalVars.erase(asId(start+3));
    758                     idMap.erase(asId(start+3));
    759                     return true;
    760                 }
    761 
    762                 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0) {
    763                     const spv::Id varId = asId(start+3);
    764 
    765                     // Avoid loads before stores
    766                     if (idMap.find(varId) == idMap.end()) {
    767                         fnLocalVars.erase(varId);
    768                         idMap.erase(varId);
    769                     }
    770 
    771                     // don't do for volatile references
    772                     if (wordCount > 4 && (spv[start+4] & spv::MemoryAccessVolatileMask)) {
    773                         fnLocalVars.erase(varId);
    774                         idMap.erase(varId);
    775                     }
    776 
    777                     // Handle flow control
    778                     if (blockMap.find(varId) == blockMap.end()) {
    779                         blockMap[varId] = blockNum;  // track block we found it in.
    780                     } else if (blockMap[varId] != blockNum) {
    781                         fnLocalVars.erase(varId);  // Ignore if crosses flow control
    782                         idMap.erase(varId);
    783                     }
    784 
    785                     return true;
    786                 }
    787 
    788                 if (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) {
    789                     const spv::Id varId = asId(start+1);
    790 
    791                     if (idMap.find(varId) == idMap.end()) {
    792                         idMap[varId] = asId(start+2);
    793                     } else {
    794                         // Remove if it has more than one store to the same pointer
    795                         fnLocalVars.erase(varId);
    796                         idMap.erase(varId);
    797                     }
    798 
    799                     // don't do for volatile references
    800                     if (wordCount > 3 && (spv[start+3] & spv::MemoryAccessVolatileMask)) {
    801                         fnLocalVars.erase(asId(start+3));
    802                         idMap.erase(asId(start+3));
    803                     }
    804 
    805                     // Handle flow control
    806                     if (blockMap.find(varId) == blockMap.end()) {
    807                         blockMap[varId] = blockNum;  // track block we found it in.
    808                     } else if (blockMap[varId] != blockNum) {
    809                         fnLocalVars.erase(varId);  // Ignore if crosses flow control
    810                         idMap.erase(varId);
    811                     }
    812 
    813                     return true;
    814                 }
    815 
    816                 return false;
    817             },
    818 
    819             // If local var id used anywhere else, don't eliminate
    820             [&](spv::Id& id) {
    821                 if (fnLocalVars.count(id) > 0) {
    822                     fnLocalVars.erase(id);
    823                     idMap.erase(id);
    824                 }
    825             }
    826         );
    827 
    828         process(
    829             [&](spv::Op opCode, unsigned start) {
    830                 if (opCode == spv::OpLoad && fnLocalVars.count(asId(start+3)) > 0)
    831                     idMap[asId(start+2)] = idMap[asId(start+3)];
    832                 return false;
    833             },
    834             op_fn_nop);
    835 
    836         // Chase replacements to their origins, in case there is a chain such as:
    837         //   2 = store 1
    838         //   3 = load 2
    839         //   4 = store 3
    840         //   5 = load 4
    841         // We want to replace uses of 5 with 1.
    842         for (const auto& idPair : idMap) {
    843             spv::Id id = idPair.first;
    844             while (idMap.find(id) != idMap.end())  // Chase to end of chain
    845                 id = idMap[id];
    846 
    847             idMap[idPair.first] = id;              // replace with final result
    848         }
    849 
    850         // Remove the load/store/variables for the ones we've discovered
    851         process(
    852             [&](spv::Op opCode, unsigned start) {
    853                 if ((opCode == spv::OpLoad  && fnLocalVars.count(asId(start+3)) > 0) ||
    854                     (opCode == spv::OpStore && fnLocalVars.count(asId(start+1)) > 0) ||
    855                     (opCode == spv::OpVariable && fnLocalVars.count(asId(start+2)) > 0)) {
    856 
    857                     stripInst(start);
    858                     return true;
    859                 }
    860 
    861                 return false;
    862             },
    863 
    864             [&](spv::Id& id) {
    865                 if (idMap.find(id) != idMap.end()) id = idMap[id];
    866             }
    867         );
    868 
    869         strip();          // strip out data we decided to eliminate
    870     }
    871 
    872     // remove bodies of uncalled functions
    873     void spirvbin_t::dceFuncs()
    874     {
    875         msg(3, 2, std::string("Removing Dead Functions: "));
    876 
    877         // TODO: There are more efficient ways to do this.
    878         bool changed = true;
    879 
    880         while (changed) {
    881             changed = false;
    882 
    883             for (auto fn = fnPos.begin(); fn != fnPos.end(); ) {
    884                 if (fn->first == entryPoint) { // don't DCE away the entry point!
    885                     ++fn;
    886                     continue;
    887                 }
    888 
    889                 const auto call_it = fnCalls.find(fn->first);
    890 
    891                 if (call_it == fnCalls.end() || call_it->second == 0) {
    892                     changed = true;
    893                     stripRange.push_back(fn->second);
    894                     fnPosDCE.insert(*fn);
    895 
    896                     // decrease counts of called functions
    897                     process(
    898                         [&](spv::Op opCode, unsigned start) {
    899                             if (opCode == spv::Op::OpFunctionCall) {
    900                                 const auto call_it = fnCalls.find(asId(start + 3));
    901                                 if (call_it != fnCalls.end()) {
    902                                     if (--call_it->second <= 0)
    903                                         fnCalls.erase(call_it);
    904                                 }
    905                             }
    906 
    907                             return true;
    908                         },
    909                         op_fn_nop,
    910                         fn->second.first,
    911                         fn->second.second);
    912 
    913                     fn = fnPos.erase(fn);
    914                 } else ++fn;
    915             }
    916         }
    917     }
    918 
    919     // remove unused function variables + decorations
    920     void spirvbin_t::dceVars()
    921     {
    922         msg(3, 2, std::string("DCE Vars: "));
    923 
    924         std::unordered_map<spv::Id, int> varUseCount;
    925 
    926         // Count function variable use
    927         process(
    928             [&](spv::Op opCode, unsigned start) {
    929                 if (opCode == spv::OpVariable) {
    930                     ++varUseCount[asId(start+2)];
    931                     return true;
    932                 } else if (opCode == spv::OpEntryPoint) {
    933                     const int wordCount = asWordCount(start);
    934                     for (int i = 4; i < wordCount; i++) {
    935                         ++varUseCount[asId(start+i)];
    936                     }
    937                     return true;
    938                 } else
    939                     return false;
    940             },
    941 
    942             [&](spv::Id& id) { if (varUseCount[id]) ++varUseCount[id]; }
    943         );
    944 
    945         // Remove single-use function variables + associated decorations and names
    946         process(
    947             [&](spv::Op opCode, unsigned start) {
    948                 if ((opCode == spv::OpVariable && varUseCount[asId(start+2)] == 1)  ||
    949                     (opCode == spv::OpDecorate && varUseCount[asId(start+1)] == 1)  ||
    950                     (opCode == spv::OpName     && varUseCount[asId(start+1)] == 1)) {
    951                         stripInst(start);
    952                 }
    953                 return true;
    954             },
    955             op_fn_nop);
    956     }
    957 
    958     // remove unused types
    959     void spirvbin_t::dceTypes()
    960     {
    961         std::vector<bool> isType(bound(), false);
    962 
    963         // for speed, make O(1) way to get to type query (map is log(n))
    964         for (const auto typeStart : typeConstPos)
    965             isType[asTypeConstId(typeStart)] = true;
    966 
    967         std::unordered_map<spv::Id, int> typeUseCount;
    968 
    969         // Count total type usage
    970         process(inst_fn_nop,
    971             [&](spv::Id& id) { if (isType[id]) ++typeUseCount[id]; }
    972         );
    973 
    974         // Remove types from deleted code
    975         for (const auto& fn : fnPosDCE)
    976             process(inst_fn_nop,
    977             [&](spv::Id& id) { if (isType[id]) --typeUseCount[id]; },
    978             fn.second.first, fn.second.second);
    979 
    980         // Remove single reference types
    981         for (const auto typeStart : typeConstPos) {
    982             const spv::Id typeId = asTypeConstId(typeStart);
    983             if (typeUseCount[typeId] == 1) {
    984                 --typeUseCount[typeId];
    985                 stripInst(typeStart);
    986             }
    987         }
    988     }
    989 
    990 
    991 #ifdef NOTDEF
    992     bool spirvbin_t::matchType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt, spv::Id gt) const
    993     {
    994         // Find the local type id "lt" and global type id "gt"
    995         const auto lt_it = typeConstPosR.find(lt);
    996         if (lt_it == typeConstPosR.end())
    997             return false;
    998 
    999         const auto typeStart = lt_it->second;
   1000 
   1001         // Search for entry in global table
   1002         const auto gtype = globalTypes.find(gt);
   1003         if (gtype == globalTypes.end())
   1004             return false;
   1005 
   1006         const auto& gdata = gtype->second;
   1007 
   1008         // local wordcount and opcode
   1009         const int     wordCount   = asWordCount(typeStart);
   1010         const spv::Op opCode      = asOpCode(typeStart);
   1011 
   1012         // no type match if opcodes don't match, or operand count doesn't match
   1013         if (opCode != opOpCode(gdata[0]) || wordCount != opWordCount(gdata[0]))
   1014             return false;
   1015 
   1016         const unsigned numOperands = wordCount - 2; // all types have a result
   1017 
   1018         const auto cmpIdRange = [&](range_t range) {
   1019             for (int x=range.first; x<std::min(range.second, wordCount); ++x)
   1020                 if (!matchType(globalTypes, asId(typeStart+x), gdata[x]))
   1021                     return false;
   1022             return true;
   1023         };
   1024 
   1025         const auto cmpConst   = [&]() { return cmpIdRange(constRange(opCode)); };
   1026         const auto cmpSubType = [&]() { return cmpIdRange(typeRange(opCode));  };
   1027 
   1028         // Compare literals in range [start,end)
   1029         const auto cmpLiteral = [&]() {
   1030             const auto range = literalRange(opCode);
   1031             return std::equal(spir.begin() + typeStart + range.first,
   1032                 spir.begin() + typeStart + std::min(range.second, wordCount),
   1033                 gdata.begin() + range.first);
   1034         };
   1035 
   1036         assert(isTypeOp(opCode) || isConstOp(opCode));
   1037 
   1038         switch (opCode) {
   1039         case spv::OpTypeOpaque:       // TODO: disable until we compare the literal strings.
   1040         case spv::OpTypeQueue:        return false;
   1041         case spv::OpTypeEvent:        // fall through...
   1042         case spv::OpTypeDeviceEvent:  // ...
   1043         case spv::OpTypeReserveId:    return false;
   1044             // for samplers, we don't handle the optional parameters yet
   1045         case spv::OpTypeSampler:      return cmpLiteral() && cmpConst() && cmpSubType() && wordCount == 8;
   1046         default:                      return cmpLiteral() && cmpConst() && cmpSubType();
   1047         }
   1048     }
   1049 
   1050 
   1051     // Look for an equivalent type in the globalTypes map
   1052     spv::Id spirvbin_t::findType(const spirvbin_t::globaltypes_t& globalTypes, spv::Id lt) const
   1053     {
   1054         // Try a recursive type match on each in turn, and return a match if we find one
   1055         for (const auto& gt : globalTypes)
   1056             if (matchType(globalTypes, lt, gt.first))
   1057                 return gt.first;
   1058 
   1059         return spv::NoType;
   1060     }
   1061 #endif // NOTDEF
   1062 
   1063     // Return start position in SPV of given type.  error if not found.
   1064     unsigned spirvbin_t::typePos(spv::Id id) const
   1065     {
   1066         const auto tid_it = typeConstPosR.find(id);
   1067         if (tid_it == typeConstPosR.end())
   1068             error("type ID not found");
   1069 
   1070         return tid_it->second;
   1071     }
   1072 
   1073     // Hash types to canonical values.  This can return ID collisions (it's a bit
   1074     // inevitable): it's up to the caller to handle that gracefully.
   1075     std::uint32_t spirvbin_t::hashType(unsigned typeStart) const
   1076     {
   1077         const unsigned wordCount   = asWordCount(typeStart);
   1078         const spv::Op  opCode      = asOpCode(typeStart);
   1079 
   1080         switch (opCode) {
   1081         case spv::OpTypeVoid:         return 0;
   1082         case spv::OpTypeBool:         return 1;
   1083         case spv::OpTypeInt:          return 3 + (spv[typeStart+3]);
   1084         case spv::OpTypeFloat:        return 5;
   1085         case spv::OpTypeVector:
   1086             return 6 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
   1087         case spv::OpTypeMatrix:
   1088             return 30 + hashType(typePos(spv[typeStart+2])) * (spv[typeStart+3] - 1);
   1089         case spv::OpTypeImage:
   1090             return 120 + hashType(typePos(spv[typeStart+2])) +
   1091                 spv[typeStart+3] +            // dimensionality
   1092                 spv[typeStart+4] * 8 * 16 +   // depth
   1093                 spv[typeStart+5] * 4 * 16 +   // arrayed
   1094                 spv[typeStart+6] * 2 * 16 +   // multisampled
   1095                 spv[typeStart+7] * 1 * 16;    // format
   1096         case spv::OpTypeSampler:
   1097             return 500;
   1098         case spv::OpTypeSampledImage:
   1099             return 502;
   1100         case spv::OpTypeArray:
   1101             return 501 + hashType(typePos(spv[typeStart+2])) * spv[typeStart+3];
   1102         case spv::OpTypeRuntimeArray:
   1103             return 5000  + hashType(typePos(spv[typeStart+2]));
   1104         case spv::OpTypeStruct:
   1105             {
   1106                 std::uint32_t hash = 10000;
   1107                 for (unsigned w=2; w < wordCount; ++w)
   1108                     hash += w * hashType(typePos(spv[typeStart+w]));
   1109                 return hash;
   1110             }
   1111 
   1112         case spv::OpTypeOpaque:         return 6000 + spv[typeStart+2];
   1113         case spv::OpTypePointer:        return 100000  + hashType(typePos(spv[typeStart+3]));
   1114         case spv::OpTypeFunction:
   1115             {
   1116                 std::uint32_t hash = 200000;
   1117                 for (unsigned w=2; w < wordCount; ++w)
   1118                     hash += w * hashType(typePos(spv[typeStart+w]));
   1119                 return hash;
   1120             }
   1121 
   1122         case spv::OpTypeEvent:           return 300000;
   1123         case spv::OpTypeDeviceEvent:     return 300001;
   1124         case spv::OpTypeReserveId:       return 300002;
   1125         case spv::OpTypeQueue:           return 300003;
   1126         case spv::OpTypePipe:            return 300004;
   1127 
   1128         case spv::OpConstantNull:        return 300005;
   1129         case spv::OpConstantSampler:     return 300006;
   1130 
   1131         case spv::OpConstantTrue:        return 300007;
   1132         case spv::OpConstantFalse:       return 300008;
   1133         case spv::OpConstantComposite:
   1134             {
   1135                 std::uint32_t hash = 300011 + hashType(typePos(spv[typeStart+1]));
   1136                 for (unsigned w=3; w < wordCount; ++w)
   1137                     hash += w * hashType(typePos(spv[typeStart+w]));
   1138                 return hash;
   1139             }
   1140         case spv::OpConstant:
   1141             {
   1142                 std::uint32_t hash = 400011 + hashType(typePos(spv[typeStart+1]));
   1143                 for (unsigned w=3; w < wordCount; ++w)
   1144                     hash += w * spv[typeStart+w];
   1145                 return hash;
   1146             }
   1147 
   1148         default:
   1149             error("unknown type opcode");
   1150             return 0;
   1151         }
   1152     }
   1153 
   1154     void spirvbin_t::mapTypeConst()
   1155     {
   1156         globaltypes_t globalTypeMap;
   1157 
   1158         msg(3, 2, std::string("Remapping Consts & Types: "));
   1159 
   1160         static const std::uint32_t softTypeIdLimit = 3011; // small prime.  TODO: get from options
   1161         static const std::uint32_t firstMappedID   = 8;    // offset into ID space
   1162 
   1163         for (auto& typeStart : typeConstPos) {
   1164             const spv::Id       resId     = asTypeConstId(typeStart);
   1165             const std::uint32_t hashval   = hashType(typeStart);
   1166 
   1167             if (isOldIdUnmapped(resId))
   1168                 localId(resId, nextUnusedId(hashval % softTypeIdLimit + firstMappedID));
   1169         }
   1170     }
   1171 
   1172 
   1173     // Strip a single binary by removing ranges given in stripRange
   1174     void spirvbin_t::strip()
   1175     {
   1176         if (stripRange.empty()) // nothing to do
   1177             return;
   1178 
   1179         // Sort strip ranges in order of traversal
   1180         std::sort(stripRange.begin(), stripRange.end());
   1181 
   1182         // Allocate a new binary big enough to hold old binary
   1183         // We'll step this iterator through the strip ranges as we go through the binary
   1184         auto strip_it = stripRange.begin();
   1185 
   1186         int strippedPos = 0;
   1187         for (unsigned word = 0; word < unsigned(spv.size()); ++word) {
   1188             if (strip_it != stripRange.end() && word >= strip_it->second)
   1189                 ++strip_it;
   1190 
   1191             if (strip_it == stripRange.end() || word < strip_it->first || word >= strip_it->second)
   1192                 spv[strippedPos++] = spv[word];
   1193         }
   1194 
   1195         spv.resize(strippedPos);
   1196         stripRange.clear();
   1197 
   1198         buildLocalMaps();
   1199     }
   1200 
   1201     // Strip a single binary by removing ranges given in stripRange
   1202     void spirvbin_t::remap(std::uint32_t opts)
   1203     {
   1204         options = opts;
   1205 
   1206         // Set up opcode tables from SpvDoc
   1207         spv::Parameterize();
   1208 
   1209         validate();  // validate header
   1210         buildLocalMaps();
   1211 
   1212         msg(3, 4, std::string("ID bound: ") + std::to_string(bound()));
   1213 
   1214         strip();        // strip out data we decided to eliminate
   1215 
   1216         if (options & OPT_LOADSTORE) optLoadStore();
   1217         if (options & OPT_FWD_LS)    forwardLoadStores();
   1218         if (options & DCE_FUNCS)     dceFuncs();
   1219         if (options & DCE_VARS)      dceVars();
   1220         if (options & DCE_TYPES)     dceTypes();
   1221         if (options & MAP_TYPES)     mapTypeConst();
   1222         if (options & MAP_NAMES)     mapNames();
   1223         if (options & MAP_FUNCS)     mapFnBodies();
   1224 
   1225         mapRemainder(); // map any unmapped IDs
   1226         applyMap();     // Now remap each shader to the new IDs we've come up with
   1227         strip();        // strip out data we decided to eliminate
   1228     }
   1229 
   1230     // remap from a memory image
   1231     void spirvbin_t::remap(std::vector<std::uint32_t>& in_spv, std::uint32_t opts)
   1232     {
   1233         spv.swap(in_spv);
   1234         remap(opts);
   1235         spv.swap(in_spv);
   1236     }
   1237 
   1238 } // namespace SPV
   1239 
   1240 #endif // defined (use_cpp11)
   1241 
   1242