Home | History | Annotate | Download | only in opt
      1 // Copyright (c) 2018 The Khronos Group Inc.
      2 // Copyright (c) 2018 Valve Corporation
      3 // Copyright (c) 2018 LunarG Inc.
      4 //
      5 // Licensed under the Apache License, Version 2.0 (the "License");
      6 // you may not use this file except in compliance with the License.
      7 // You may obtain a copy of the License at
      8 //
      9 //     http://www.apache.org/licenses/LICENSE-2.0
     10 //
     11 // Unless required by applicable law or agreed to in writing, software
     12 // distributed under the License is distributed on an "AS IS" BASIS,
     13 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     14 // See the License for the specific language governing permissions and
     15 // limitations under the License.
     16 
     17 #include "inst_bindless_check_pass.h"
     18 
     19 namespace {
     20 
     21 // Input Operand Indices
     22 static const int kSpvImageSampleImageIdInIdx = 0;
     23 static const int kSpvSampledImageImageIdInIdx = 0;
     24 static const int kSpvSampledImageSamplerIdInIdx = 1;
     25 static const int kSpvImageSampledImageIdInIdx = 0;
     26 static const int kSpvLoadPtrIdInIdx = 0;
     27 static const int kSpvAccessChainBaseIdInIdx = 0;
     28 static const int kSpvAccessChainIndex0IdInIdx = 1;
     29 static const int kSpvTypePointerTypeIdInIdx = 1;
     30 static const int kSpvTypeArrayLengthIdInIdx = 1;
     31 static const int kSpvConstantValueInIdx = 0;
     32 
     33 }  // anonymous namespace
     34 
     35 namespace spvtools {
     36 namespace opt {
     37 
     38 void InstBindlessCheckPass::GenBindlessCheckCode(
     39     BasicBlock::iterator ref_inst_itr,
     40     UptrVectorIterator<BasicBlock> ref_block_itr, uint32_t instruction_idx,
     41     uint32_t stage_idx, std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
     42   // Look for reference through bindless descriptor. If not, return.
     43   std::unique_ptr<BasicBlock> new_blk_ptr;
     44   uint32_t image_id;
     45   switch (ref_inst_itr->opcode()) {
     46     case SpvOp::SpvOpImageSampleImplicitLod:
     47     case SpvOp::SpvOpImageSampleExplicitLod:
     48     case SpvOp::SpvOpImageSampleDrefImplicitLod:
     49     case SpvOp::SpvOpImageSampleDrefExplicitLod:
     50     case SpvOp::SpvOpImageSampleProjImplicitLod:
     51     case SpvOp::SpvOpImageSampleProjExplicitLod:
     52     case SpvOp::SpvOpImageSampleProjDrefImplicitLod:
     53     case SpvOp::SpvOpImageSampleProjDrefExplicitLod:
     54     case SpvOp::SpvOpImageGather:
     55     case SpvOp::SpvOpImageDrefGather:
     56     case SpvOp::SpvOpImageQueryLod:
     57     case SpvOp::SpvOpImageSparseSampleImplicitLod:
     58     case SpvOp::SpvOpImageSparseSampleExplicitLod:
     59     case SpvOp::SpvOpImageSparseSampleDrefImplicitLod:
     60     case SpvOp::SpvOpImageSparseSampleDrefExplicitLod:
     61     case SpvOp::SpvOpImageSparseSampleProjImplicitLod:
     62     case SpvOp::SpvOpImageSparseSampleProjExplicitLod:
     63     case SpvOp::SpvOpImageSparseSampleProjDrefImplicitLod:
     64     case SpvOp::SpvOpImageSparseSampleProjDrefExplicitLod:
     65     case SpvOp::SpvOpImageSparseGather:
     66     case SpvOp::SpvOpImageSparseDrefGather:
     67     case SpvOp::SpvOpImageFetch:
     68     case SpvOp::SpvOpImageRead:
     69     case SpvOp::SpvOpImageQueryFormat:
     70     case SpvOp::SpvOpImageQueryOrder:
     71     case SpvOp::SpvOpImageQuerySizeLod:
     72     case SpvOp::SpvOpImageQuerySize:
     73     case SpvOp::SpvOpImageQueryLevels:
     74     case SpvOp::SpvOpImageQuerySamples:
     75     case SpvOp::SpvOpImageSparseFetch:
     76     case SpvOp::SpvOpImageSparseRead:
     77     case SpvOp::SpvOpImageWrite:
     78       image_id =
     79           ref_inst_itr->GetSingleWordInOperand(kSpvImageSampleImageIdInIdx);
     80       break;
     81     default:
     82       return;
     83   }
     84   Instruction* image_inst = get_def_use_mgr()->GetDef(image_id);
     85   uint32_t load_id;
     86   Instruction* load_inst;
     87   if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
     88     load_id = image_inst->GetSingleWordInOperand(kSpvSampledImageImageIdInIdx);
     89     load_inst = get_def_use_mgr()->GetDef(load_id);
     90   } else if (image_inst->opcode() == SpvOp::SpvOpImage) {
     91     load_id = image_inst->GetSingleWordInOperand(kSpvImageSampledImageIdInIdx);
     92     load_inst = get_def_use_mgr()->GetDef(load_id);
     93   } else {
     94     load_id = image_id;
     95     load_inst = image_inst;
     96     image_id = 0;
     97   }
     98   if (load_inst->opcode() != SpvOp::SpvOpLoad) {
     99     // TODO(greg-lunarg): Handle additional possibilities
    100     return;
    101   }
    102   uint32_t ptr_id = load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx);
    103   Instruction* ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
    104   if (ptr_inst->opcode() != SpvOp::SpvOpAccessChain) return;
    105   if (ptr_inst->NumInOperands() != 2) {
    106     assert(false && "unexpected bindless index number");
    107     return;
    108   }
    109   uint32_t index_id =
    110       ptr_inst->GetSingleWordInOperand(kSpvAccessChainIndex0IdInIdx);
    111   ptr_id = ptr_inst->GetSingleWordInOperand(kSpvAccessChainBaseIdInIdx);
    112   ptr_inst = get_def_use_mgr()->GetDef(ptr_id);
    113   if (ptr_inst->opcode() != SpvOpVariable) {
    114     assert(false && "unexpected bindless base");
    115     return;
    116   }
    117   uint32_t var_type_id = ptr_inst->type_id();
    118   Instruction* var_type_inst = get_def_use_mgr()->GetDef(var_type_id);
    119   uint32_t ptr_type_id =
    120       var_type_inst->GetSingleWordInOperand(kSpvTypePointerTypeIdInIdx);
    121   Instruction* ptr_type_inst = get_def_use_mgr()->GetDef(ptr_type_id);
    122   // TODO(greg-lunarg): Handle RuntimeArray. Will need to pull length
    123   // out of debug input buffer.
    124   if (ptr_type_inst->opcode() != SpvOpTypeArray) return;
    125   // If index and bound both compile-time constants and index < bound,
    126   // return without changing
    127   uint32_t length_id =
    128       ptr_type_inst->GetSingleWordInOperand(kSpvTypeArrayLengthIdInIdx);
    129   Instruction* index_inst = get_def_use_mgr()->GetDef(index_id);
    130   Instruction* length_inst = get_def_use_mgr()->GetDef(length_id);
    131   if (index_inst->opcode() == SpvOpConstant &&
    132       length_inst->opcode() == SpvOpConstant &&
    133       index_inst->GetSingleWordInOperand(kSpvConstantValueInIdx) <
    134           length_inst->GetSingleWordInOperand(kSpvConstantValueInIdx))
    135     return;
    136   // Generate full runtime bounds test code with true branch
    137   // being full reference and false branch being debug output and zero
    138   // for the referenced value.
    139   MovePreludeCode(ref_inst_itr, ref_block_itr, &new_blk_ptr);
    140   InstructionBuilder builder(
    141       context(), &*new_blk_ptr,
    142       IRContext::kAnalysisDefUse | IRContext::kAnalysisInstrToBlockMapping);
    143   uint32_t error_id = builder.GetUintConstantId(kInstErrorBindlessBounds);
    144   Instruction* ult_inst =
    145       builder.AddBinaryOp(GetBoolId(), SpvOpULessThan, index_id, length_id);
    146   uint32_t merge_blk_id = TakeNextId();
    147   uint32_t valid_blk_id = TakeNextId();
    148   uint32_t invalid_blk_id = TakeNextId();
    149   std::unique_ptr<Instruction> merge_label(NewLabel(merge_blk_id));
    150   std::unique_ptr<Instruction> valid_label(NewLabel(valid_blk_id));
    151   std::unique_ptr<Instruction> invalid_label(NewLabel(invalid_blk_id));
    152   (void)builder.AddConditionalBranch(ult_inst->result_id(), valid_blk_id,
    153                                      invalid_blk_id, merge_blk_id,
    154                                      SpvSelectionControlMaskNone);
    155   // Close selection block and gen valid reference block
    156   new_blocks->push_back(std::move(new_blk_ptr));
    157   new_blk_ptr.reset(new BasicBlock(std::move(valid_label)));
    158   builder.SetInsertPoint(&*new_blk_ptr);
    159   // Clone descriptor load
    160   Instruction* new_load_inst =
    161       builder.AddLoad(load_inst->type_id(),
    162                       load_inst->GetSingleWordInOperand(kSpvLoadPtrIdInIdx));
    163   uint32_t new_load_id = new_load_inst->result_id();
    164   get_decoration_mgr()->CloneDecorations(load_inst->result_id(), new_load_id);
    165   uint32_t new_image_id = new_load_id;
    166   // Clone Image/SampledImage with new load, if needed
    167   if (image_id != 0) {
    168     if (image_inst->opcode() == SpvOp::SpvOpSampledImage) {
    169       Instruction* new_image_inst = builder.AddBinaryOp(
    170           image_inst->type_id(), SpvOpSampledImage, new_load_id,
    171           image_inst->GetSingleWordInOperand(kSpvSampledImageSamplerIdInIdx));
    172       new_image_id = new_image_inst->result_id();
    173     } else {
    174       assert(image_inst->opcode() == SpvOp::SpvOpImage && "expecting OpImage");
    175       Instruction* new_image_inst =
    176           builder.AddUnaryOp(image_inst->type_id(), SpvOpImage, new_load_id);
    177       new_image_id = new_image_inst->result_id();
    178     }
    179     get_decoration_mgr()->CloneDecorations(image_id, new_image_id);
    180   }
    181   // Clone original reference using new image code
    182   std::unique_ptr<Instruction> new_ref_inst(ref_inst_itr->Clone(context()));
    183   uint32_t ref_result_id = ref_inst_itr->result_id();
    184   uint32_t new_ref_id = 0;
    185   if (ref_result_id != 0) {
    186     new_ref_id = TakeNextId();
    187     new_ref_inst->SetResultId(new_ref_id);
    188   }
    189   new_ref_inst->SetInOperand(kSpvImageSampleImageIdInIdx, {new_image_id});
    190   // Register new reference and add to new block
    191   builder.AddInstruction(std::move(new_ref_inst));
    192   if (new_ref_id != 0)
    193     get_decoration_mgr()->CloneDecorations(ref_result_id, new_ref_id);
    194   // Close valid block and gen invalid block
    195   (void)builder.AddBranch(merge_blk_id);
    196   new_blocks->push_back(std::move(new_blk_ptr));
    197   new_blk_ptr.reset(new BasicBlock(std::move(invalid_label)));
    198   builder.SetInsertPoint(&*new_blk_ptr);
    199   uint32_t u_index_id = GenUintCastCode(index_id, &builder);
    200   GenDebugStreamWrite(instruction_idx, stage_idx,
    201                       {error_id, u_index_id, length_id}, &builder);
    202   // Remember last invalid block id
    203   uint32_t last_invalid_blk_id = new_blk_ptr->GetLabelInst()->result_id();
    204   // Gen zero for invalid  reference
    205   uint32_t ref_type_id = ref_inst_itr->type_id();
    206   // Close invalid block and gen merge block
    207   (void)builder.AddBranch(merge_blk_id);
    208   new_blocks->push_back(std::move(new_blk_ptr));
    209   new_blk_ptr.reset(new BasicBlock(std::move(merge_label)));
    210   builder.SetInsertPoint(&*new_blk_ptr);
    211   // Gen phi of new reference and zero, if necessary, and replace the
    212   // result id of the original reference with that of the Phi. Kill original
    213   // reference and move in remainder of original block.
    214   if (new_ref_id != 0) {
    215     Instruction* phi_inst = builder.AddPhi(
    216         ref_type_id, {new_ref_id, valid_blk_id, builder.GetNullId(ref_type_id),
    217                       last_invalid_blk_id});
    218     context()->ReplaceAllUsesWith(ref_result_id, phi_inst->result_id());
    219   }
    220   context()->KillInst(&*ref_inst_itr);
    221   MovePostludeCode(ref_block_itr, &new_blk_ptr);
    222   // Add remainder/merge block to new blocks
    223   new_blocks->push_back(std::move(new_blk_ptr));
    224 }
    225 
    226 void InstBindlessCheckPass::InitializeInstBindlessCheck() {
    227   // Initialize base class
    228   InitializeInstrument();
    229   // Look for related extensions
    230   ext_descriptor_indexing_defined_ = false;
    231   for (auto& ei : get_module()->extensions()) {
    232     const char* ext_name =
    233         reinterpret_cast<const char*>(&ei.GetInOperand(0).words[0]);
    234     if (strcmp(ext_name, "SPV_EXT_descriptor_indexing") == 0) {
    235       ext_descriptor_indexing_defined_ = true;
    236       break;
    237     }
    238   }
    239 }
    240 
    241 Pass::Status InstBindlessCheckPass::ProcessImpl() {
    242   // Perform instrumentation on each entry point function in module
    243   InstProcessFunction pfn =
    244       [this](BasicBlock::iterator ref_inst_itr,
    245              UptrVectorIterator<BasicBlock> ref_block_itr,
    246              uint32_t instruction_idx, uint32_t stage_idx,
    247              std::vector<std::unique_ptr<BasicBlock>>* new_blocks) {
    248         return GenBindlessCheckCode(ref_inst_itr, ref_block_itr,
    249                                     instruction_idx, stage_idx, new_blocks);
    250       };
    251   bool modified = InstProcessEntryPointCallTree(pfn);
    252   // This pass does not update inst->blk info
    253   context()->InvalidateAnalyses(IRContext::kAnalysisInstrToBlockMapping);
    254   return modified ? Status::SuccessWithChange : Status::SuccessWithoutChange;
    255 }
    256 
    257 Pass::Status InstBindlessCheckPass::Process() {
    258   InitializeInstBindlessCheck();
    259   return ProcessImpl();
    260 }
    261 
    262 }  // namespace opt
    263 }  // namespace spvtools
    264