//
// Copyright (C) 2018 Google, Inc.
//
// All rights reserved.
//
// Redistribution and use in source and binary forms, with or without
// modification, are permitted provided that the following conditions
// are met:
//
//    Redistributions of source code must retain the above copyright
//    notice, this list of conditions and the following disclaimer.
//
//    Redistributions in binary form must reproduce the above
//    copyright notice, this list of conditions and the following
//    disclaimer in the documentation and/or other materials provided
//    with the distribution.
//
//    Neither the name of 3Dlabs Inc. Ltd. nor the names of its
//    contributors may be used to endorse or promote products derived
//    from this software without specific prior written permission.
//
// THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS
// "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT
// LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS
// FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE
// COPYRIGHT HOLDERS OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT,
// INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
// BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
// LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
// CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT
// LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN
// ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE
// POSSIBILITY OF SUCH DAMAGE.

//
// Post-processing for SPIR-V IR, in internal form, not standard binary form.
//
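// Builder::postProcess() below is the module-level entry point. It is expected
// to be called once per module, after everything has been built and before the
// binary is written out (see the caller in GlslangToSpv.cpp); the other
// functions here are per-instruction and per-type helpers that it drives.
//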

#include <cassert>
#include <cstdlib>

#include <unordered_set>
#include <algorithm>

#include "SpvBuilder.h"

#include "spirv.hpp"
#include "GlslangToSpv.h"
namespace spv {
    #include "GLSL.std.450.h"
    #include "GLSL.ext.KHR.h"
    #include "GLSL.ext.EXT.h"
#ifdef AMD_EXTENSIONS
    #include "GLSL.ext.AMD.h"
#endif
#ifdef NV_EXTENSIONS
    #include "GLSL.ext.NV.h"
#endif
}

namespace spv {

// Hook to visit each operand type and result type of an instruction.
// Will be called multiple times for one instruction, once for each typed
// operand and the result.
void Builder::postProcessType(const Instruction& inst, Id typeId)
{
    // Characterize the type being questioned
    Id basicTypeOp = getMostBasicTypeClass(typeId);
    int width = 0;
    if (basicTypeOp == OpTypeFloat || basicTypeOp == OpTypeInt)
        width = getScalarTypeWidth(typeId);

    // Do opcode-specific checks
    switch (inst.getOpCode()) {
    case OpLoad:
    case OpStore:
        if (basicTypeOp == OpTypeStruct) {
            if (containsType(typeId, OpTypeInt, 8))
                addCapability(CapabilityInt8);
            if (containsType(typeId, OpTypeInt, 16))
                addCapability(CapabilityInt16);
            if (containsType(typeId, OpTypeFloat, 16))
                addCapability(CapabilityFloat16);
        } else {
            StorageClass storageClass = getStorageClass(inst.getIdOperand(0));
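            // The storage classes listed below get their 8-bit/16-bit access
            // capabilities from the 8/16-bit storage extensions instead: either when
            // the variable is created (see createSpvVariable in GlslangToSpv.cpp) or,
            // for PhysicalStorageBufferEXT, by the pointer-type scan at the end of
            // postProcess() below. Only the remaining classes need the plain
            // Int8/Int16/Float16 capabilities here.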
            if (width == 8) {
                switch (storageClass) {
                case StorageClassPhysicalStorageBufferEXT:
                case StorageClassUniform:
                case StorageClassStorageBuffer:
                case StorageClassPushConstant:
                    break;
                default:
                    addCapability(CapabilityInt8);
                    break;
                }
            } else if (width == 16) {
                switch (storageClass) {
                case StorageClassPhysicalStorageBufferEXT:
                case StorageClassUniform:
                case StorageClassStorageBuffer:
                case StorageClassPushConstant:
                case StorageClassInput:
                case StorageClassOutput:
                    break;
                default:
                    if (basicTypeOp == OpTypeInt)
                        addCapability(CapabilityInt16);
                    if (basicTypeOp == OpTypeFloat)
                        addCapability(CapabilityFloat16);
                    break;
                }
            }
        }
        break;
    case OpAccessChain:
    case OpPtrAccessChain:
    case OpCopyObject:
    case OpFConvert:
    case OpSConvert:
    case OpUConvert:
        break;
    case OpExtInst:
#ifdef AMD_EXTENSIONS
        switch (inst.getImmediateOperand(1)) {
        case GLSLstd450Frexp:
        case GLSLstd450FrexpStruct:
            if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeInt, 16))
                addExtension(spv::E_SPV_AMD_gpu_shader_int16);
            break;
        case GLSLstd450InterpolateAtCentroid:
        case GLSLstd450InterpolateAtSample:
        case GLSLstd450InterpolateAtOffset:
            if (getSpvVersion() < glslang::EShTargetSpv_1_3 && containsType(typeId, OpTypeFloat, 16))
                addExtension(spv::E_SPV_AMD_gpu_shader_half_float);
            break;
        default:
            break;
        }
#endif
        break;
    default:
        if (basicTypeOp == OpTypeFloat && width == 16)
            addCapability(CapabilityFloat16);
        if (basicTypeOp == OpTypeInt && width == 16)
            addCapability(CapabilityInt16);
        if (basicTypeOp == OpTypeInt && width == 8)
            addCapability(CapabilityInt8);
        break;
    }
}

// Called for each instruction that resides in a block.
void Builder::postProcess(Instruction& inst)
{
    // Add capabilities based simply on the opcode.
    switch (inst.getOpCode()) {
    case OpExtInst:
        switch (inst.getImmediateOperand(1)) {
        case GLSLstd450InterpolateAtCentroid:
        case GLSLstd450InterpolateAtSample:
        case GLSLstd450InterpolateAtOffset:
            addCapability(CapabilityInterpolationFunction);
            break;
        default:
            break;
        }
        break;
    case OpDPdxFine:
    case OpDPdyFine:
    case OpFwidthFine:
    case OpDPdxCoarse:
    case OpDPdyCoarse:
    case OpFwidthCoarse:
        addCapability(CapabilityDerivativeControl);
        break;

    case OpImageQueryLod:
    case OpImageQuerySize:
    case OpImageQuerySizeLod:
    case OpImageQuerySamples:
    case OpImageQueryLevels:
        addCapability(CapabilityImageQuery);
        break;

#ifdef NV_EXTENSIONS
    case OpGroupNonUniformPartitionNV:
        addExtension(E_SPV_NV_shader_subgroup_partitioned);
        addCapability(CapabilityGroupNonUniformPartitionedNV);
        break;
#endif

    case OpLoad:
    case OpStore:
        {
            // For any load/store to a PhysicalStorageBufferEXT, walk the access-chain
            // index list to compute the misalignment. The pre-existing alignment value
            // (set via Builder::AccessChain::alignment) only accounts for the base of
            // the reference type and any scalar component selection in the access chain;
            // this function computes the rest from the SPIR-V Offset, ArrayStride, and
            // MatrixStride decorations.
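            // For example, an ArrayStride of 16 combined with a member Offset of 20
            // ORs to 0b10100; picking the lowest set bit at the end then leaves a
            // guaranteed alignment of 4 bytes for the access.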
            Instruction *accessChain = module.getInstruction(inst.getIdOperand(0));
            if (accessChain->getOpCode() == OpAccessChain) {
                Instruction *base = module.getInstruction(accessChain->getIdOperand(0));
                // Get the type of the base of the access chain. It must be a pointer type.
                Id typeId = base->getTypeId();
                Instruction *type = module.getInstruction(typeId);
                assert(type->getOpCode() == OpTypePointer);
                if (type->getImmediateOperand(0) != StorageClassPhysicalStorageBufferEXT) {
                    break;
                }
                // Get the pointee type.
                typeId = type->getIdOperand(1);
                type = module.getInstruction(typeId);
                // Walk the index list for the access chain. For each index, find any
                // misalignment that can apply when accessing the member/element via
                // Offset/ArrayStride/MatrixStride decorations, and bitwise OR them all
                // together.
                int alignment = 0;
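                // Operand 0 of the OpAccessChain is its base pointer; the indexes to
                // walk start at operand 1.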
                for (int i = 1; i < accessChain->getNumOperands(); ++i) {
                    Instruction *idx = module.getInstruction(accessChain->getIdOperand(i));
                    if (type->getOpCode() == OpTypeStruct) {
                        assert(idx->getOpCode() == OpConstant);
                        int c = idx->getImmediateOperand(0);

                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
                            if (decoration.get()->getOpCode() == OpMemberDecorate &&
                                decoration.get()->getIdOperand(0) == typeId &&
                                decoration.get()->getImmediateOperand(1) == c &&
                                (decoration.get()->getImmediateOperand(2) == DecorationOffset ||
                                 decoration.get()->getImmediateOperand(2) == DecorationMatrixStride)) {
                                alignment |= decoration.get()->getImmediateOperand(3);
                            }
                        };
                        std::for_each(decorations.begin(), decorations.end(), function);
                        // Get the next member type
                        typeId = type->getIdOperand(c);
                        type = module.getInstruction(typeId);
                    } else if (type->getOpCode() == OpTypeArray ||
                               type->getOpCode() == OpTypeRuntimeArray) {
                        const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
                            if (decoration.get()->getOpCode() == OpDecorate &&
                                decoration.get()->getIdOperand(0) == typeId &&
                                decoration.get()->getImmediateOperand(1) == DecorationArrayStride) {
                                alignment |= decoration.get()->getImmediateOperand(2);
                            }
                        };
                        std::for_each(decorations.begin(), decorations.end(), function);
                        // Get the element type
                        typeId = type->getIdOperand(0);
                        type = module.getInstruction(typeId);
                    } else {
                        // Once we get to any non-aggregate type, we're done.
                        break;
                    }
                }
                assert(inst.getNumOperands() >= 3);
                unsigned int memoryAccess = inst.getImmediateOperand((inst.getOpCode() == OpStore) ? 2 : 1);
                assert(memoryAccess & MemoryAccessAlignedMask);
                // Compute the index of the alignment operand.
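                // (OpStore carries an extra Object operand before the memory-access
                // mask, so its Aligned literal sits one operand later than OpLoad's.)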
                int alignmentIdx = 2;
                if (memoryAccess & MemoryAccessVolatileMask)
                    alignmentIdx++;
                if (inst.getOpCode() == OpStore)
                    alignmentIdx++;
                // Merge new and old (mis)alignment
                alignment |= inst.getImmediateOperand(alignmentIdx);
                // Pick the LSB
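                // (keep only the lowest set bit, i.e. the largest power of two the
                // final address is still guaranteed to be aligned to)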
                alignment = alignment & ~(alignment & (alignment-1));
                // Update the Aligned operand
                inst.setImmediateOperand(alignmentIdx, alignment);
            }
            break;
        }

    default:
        break;
    }

    // Checks based on type
    if (inst.getTypeId() != NoType)
        postProcessType(inst, inst.getTypeId());
    for (int op = 0; op < inst.getNumOperands(); ++op) {
        if (inst.isIdOperand(op)) {
            // In blocks, these are always result ids, but we are relying on
            // getTypeId() to return NoType for things like OpLabel.
            if (getTypeId(inst.getIdOperand(op)) != NoType)
                postProcessType(inst, getTypeId(inst.getIdOperand(op)));
        }
    }
}

// Called for each instruction in a reachable block.
void Builder::postProcessReachable(const Instruction&)
{
    // Did have code here, but it is questionable to do so without also deleting
    // the instructions.
}

// comment in header
void Builder::postProcess()
{
    std::unordered_set<const Block*> reachableBlocks;
    std::unordered_set<Id> unreachableDefinitions;
    // Collect IDs defined in unreachable blocks. For each function, label the
    // reachable blocks first. Then for each unreachable block, collect the
    // result IDs of the instructions in it.
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
        Function* f = *fi;
        Block* entry = f->getEntryBlock();
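        // inReadableOrder() only visits blocks reachable from the entry block,
        // so every block it hands to the callback can be recorded as reachable.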
        inReadableOrder(entry, [&reachableBlocks](const Block* b) { reachableBlocks.insert(b); });
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
            Block* b = *bi;
            if (reachableBlocks.count(b) == 0) {
                for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
                    unreachableDefinitions.insert(ii->get()->getResultId());
            }
        }
    }

    // Remove decorations that apply to IDs defined in unreachable blocks.
    decorations.erase(std::remove_if(decorations.begin(), decorations.end(),
        [&unreachableDefinitions](std::unique_ptr<Instruction>& I) -> bool {
            Id decoration_id = I.get()->getIdOperand(0);
            return unreachableDefinitions.count(decoration_id) != 0;
        }),
        decorations.end());

    // Add per-instruction capabilities, extensions, etc.

    // Process all reachable instructions...
    for (auto bi = reachableBlocks.cbegin(); bi != reachableBlocks.cend(); ++bi) {
        const Block* block = *bi;
        const auto function = [this](const std::unique_ptr<Instruction>& inst) { postProcessReachable(*inst.get()); };
        std::for_each(block->getInstructions().begin(), block->getInstructions().end(), function);
    }

    // Process all block-contained instructions
    for (auto fi = module.getFunctions().cbegin(); fi != module.getFunctions().cend(); fi++) {
        Function* f = *fi;
        for (auto bi = f->getBlocks().cbegin(); bi != f->getBlocks().cend(); bi++) {
            Block* b = *bi;
            for (auto ii = b->getInstructions().cbegin(); ii != b->getInstructions().cend(); ii++)
                postProcess(*ii->get());

            // For all local variables that contain pointers to PhysicalStorageBufferEXT, check whether
            // there is an existing restrict/aliased decoration. If we don't find one, add Aliased as the
            // default.
            for (auto vi = b->getLocalVariables().cbegin(); vi != b->getLocalVariables().cend(); vi++) {
                const Instruction& inst = *vi->get();
                Id resultId = inst.getResultId();
                if (containsPhysicalStorageBufferOrArray(getDerefTypeId(resultId))) {
                    bool foundDecoration = false;
                    const auto function = [&](const std::unique_ptr<Instruction>& decoration) {
                        if (decoration.get()->getIdOperand(0) == resultId &&
                            decoration.get()->getOpCode() == OpDecorate &&
                            (decoration.get()->getImmediateOperand(1) == spv::DecorationAliasedPointerEXT ||
                             decoration.get()->getImmediateOperand(1) == spv::DecorationRestrictPointerEXT)) {
                            foundDecoration = true;
                        }
                    };
                    std::for_each(decorations.begin(), decorations.end(), function);
                    if (!foundDecoration) {
                        addDecoration(resultId, spv::DecorationAliasedPointerEXT);
                    }
                }
            }
        }
    }

    // Look for any 8/16-bit type in the physical storage buffer class, and add the
    // appropriate capability. This happens in createSpvVariable for other storage
    // classes, but there isn't always a variable for physical storage buffer.
    for (int t = 0; t < (int)groupedTypes[OpTypePointer].size(); ++t) {
        Instruction* type = groupedTypes[OpTypePointer][t];
        if (type->getImmediateOperand(0) == (unsigned)StorageClassPhysicalStorageBufferEXT) {
            if (containsType(type->getIdOperand(1), OpTypeInt, 8)) {
                addExtension(spv::E_SPV_KHR_8bit_storage);
                addCapability(spv::CapabilityStorageBuffer8BitAccess);
            }
            if (containsType(type->getIdOperand(1), OpTypeInt, 16) ||
                containsType(type->getIdOperand(1), OpTypeFloat, 16)) {
                addExtension(spv::E_SPV_KHR_16bit_storage);
                addCapability(spv::CapabilityStorageBuffer16BitAccess);
            }
        }
    }
}

}; // end spv namespace