Home | History | Annotate | Download | only in val
      1 // Copyright (c) 2018 Google LLC.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/val/validate.h"
     16 
     17 #include <algorithm>
     18 
     19 #include "source/opcode.h"
     20 #include "source/val/instruction.h"
     21 #include "source/val/validation_state.h"
     22 
     23 namespace spvtools {
     24 namespace val {
     25 namespace {
     26 
     27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
     28   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
     29   const auto function_type = _.FindDef(function_type_id);
     30   if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
     31     return _.diag(SPV_ERROR_INVALID_ID, inst)
     32            << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
     33            << "' is not a function type.";
     34   }
     35 
     36   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
     37   if (return_id != inst->type_id()) {
     38     return _.diag(SPV_ERROR_INVALID_ID, inst)
     39            << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
     40            << "' does not match the Function Type's return type <id> '"
     41            << _.getIdName(return_id) << "'.";
     42   }
     43 
     44   for (auto& pair : inst->uses()) {
     45     const auto* use = pair.first;
     46     const std::vector<SpvOp> acceptable = {
     47         SpvOpFunctionCall,
     48         SpvOpEntryPoint,
     49         SpvOpEnqueueKernel,
     50         SpvOpGetKernelNDrangeSubGroupCount,
     51         SpvOpGetKernelNDrangeMaxSubGroupSize,
     52         SpvOpGetKernelWorkGroupSize,
     53         SpvOpGetKernelPreferredWorkGroupSizeMultiple,
     54         SpvOpGetKernelLocalSizeForSubgroupCount,
     55         SpvOpGetKernelMaxNumSubgroups};
     56     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
     57         acceptable.end()) {
     58       return _.diag(SPV_ERROR_INVALID_ID, use)
     59              << "Invalid use of function result id " << _.getIdName(inst->id())
     60              << ".";
     61     }
     62   }
     63 
     64   return SPV_SUCCESS;
     65 }
     66 
     67 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
     68                                        const Instruction* inst) {
     69   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
     70   size_t param_index = 0;
     71   size_t inst_num = inst->LineNum() - 1;
     72   if (inst_num == 0) {
     73     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
     74            << "Function parameter cannot be the first instruction.";
     75   }
     76 
     77   auto func_inst = &_.ordered_instructions()[inst_num];
     78   while (--inst_num) {
     79     func_inst = &_.ordered_instructions()[inst_num];
     80     if (func_inst->opcode() == SpvOpFunction) {
     81       break;
     82     } else if (func_inst->opcode() == SpvOpFunctionParameter) {
     83       ++param_index;
     84     }
     85   }
     86 
     87   if (func_inst->opcode() != SpvOpFunction) {
     88     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
     89            << "Function parameter must be preceded by a function.";
     90   }
     91 
     92   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
     93   const auto function_type = _.FindDef(function_type_id);
     94   if (!function_type) {
     95     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
     96            << "Missing function type definition.";
     97   }
     98   if (param_index >= function_type->words().size() - 3) {
     99     return _.diag(SPV_ERROR_INVALID_ID, inst)
    100            << "Too many OpFunctionParameters for " << func_inst->id()
    101            << ": expected " << function_type->words().size() - 3
    102            << " based on the function's type";
    103   }
    104 
    105   const auto param_type =
    106       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
    107   if (!param_type || inst->type_id() != param_type->id()) {
    108     return _.diag(SPV_ERROR_INVALID_ID, inst)
    109            << "OpFunctionParameter Result Type <id> '"
    110            << _.getIdName(inst->type_id())
    111            << "' does not match the OpTypeFunction parameter "
    112               "type of the same index.";
    113   }
    114 
    115   // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased,
    116   // RestrictPointerEXT, or AliasedPointerEXT.
    117   auto param_nonarray_type_id = param_type->id();
    118   while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) {
    119     param_nonarray_type_id =
    120         _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u);
    121   }
    122   if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) {
    123     auto param_nonarray_type = _.FindDef(param_nonarray_type_id);
    124     if (param_nonarray_type->GetOperandAs<uint32_t>(1u) ==
    125         SpvStorageClassPhysicalStorageBufferEXT) {
    126       // check for Aliased or Restrict
    127       const auto& decorations = _.id_decorations(inst->id());
    128 
    129       bool foundAliased = std::any_of(
    130           decorations.begin(), decorations.end(), [](const Decoration& d) {
    131             return SpvDecorationAliased == d.dec_type();
    132           });
    133 
    134       bool foundRestrict = std::any_of(
    135           decorations.begin(), decorations.end(), [](const Decoration& d) {
    136             return SpvDecorationRestrict == d.dec_type();
    137           });
    138 
    139       if (!foundAliased && !foundRestrict) {
    140         return _.diag(SPV_ERROR_INVALID_ID, inst)
    141                << "OpFunctionParameter " << inst->id()
    142                << ": expected Aliased or Restrict for PhysicalStorageBufferEXT "
    143                   "pointer.";
    144       }
    145       if (foundAliased && foundRestrict) {
    146         return _.diag(SPV_ERROR_INVALID_ID, inst)
    147                << "OpFunctionParameter " << inst->id()
    148                << ": can't specify both Aliased and Restrict for "
    149                   "PhysicalStorageBufferEXT pointer.";
    150       }
    151     } else {
    152       const auto pointee_type_id =
    153           param_nonarray_type->GetOperandAs<uint32_t>(2);
    154       const auto pointee_type = _.FindDef(pointee_type_id);
    155       if (SpvOpTypePointer == pointee_type->opcode() &&
    156           pointee_type->GetOperandAs<uint32_t>(1u) ==
    157               SpvStorageClassPhysicalStorageBufferEXT) {
    158         // check for AliasedPointerEXT/RestrictPointerEXT
    159         const auto& decorations = _.id_decorations(inst->id());
    160 
    161         bool foundAliased = std::any_of(
    162             decorations.begin(), decorations.end(), [](const Decoration& d) {
    163               return SpvDecorationAliasedPointerEXT == d.dec_type();
    164             });
    165 
    166         bool foundRestrict = std::any_of(
    167             decorations.begin(), decorations.end(), [](const Decoration& d) {
    168               return SpvDecorationRestrictPointerEXT == d.dec_type();
    169             });
    170 
    171         if (!foundAliased && !foundRestrict) {
    172           return _.diag(SPV_ERROR_INVALID_ID, inst)
    173                  << "OpFunctionParameter " << inst->id()
    174                  << ": expected AliasedPointerEXT or RestrictPointerEXT for "
    175                     "PhysicalStorageBufferEXT pointer.";
    176         }
    177         if (foundAliased && foundRestrict) {
    178           return _.diag(SPV_ERROR_INVALID_ID, inst)
    179                  << "OpFunctionParameter " << inst->id()
    180                  << ": can't specify both AliasedPointerEXT and "
    181                     "RestrictPointerEXT for PhysicalStorageBufferEXT pointer.";
    182         }
    183       }
    184     }
    185   }
    186 
    187   return SPV_SUCCESS;
    188 }
    189 
    190 spv_result_t ValidateFunctionCall(ValidationState_t& _,
    191                                   const Instruction* inst) {
    192   const auto function_id = inst->GetOperandAs<uint32_t>(2);
    193   const auto function = _.FindDef(function_id);
    194   if (!function || SpvOpFunction != function->opcode()) {
    195     return _.diag(SPV_ERROR_INVALID_ID, inst)
    196            << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
    197            << "' is not a function.";
    198   }
    199 
    200   auto return_type = _.FindDef(function->type_id());
    201   if (!return_type || return_type->id() != inst->type_id()) {
    202     return _.diag(SPV_ERROR_INVALID_ID, inst)
    203            << "OpFunctionCall Result Type <id> '"
    204            << _.getIdName(inst->type_id())
    205            << "'s type does not match Function <id> '"
    206            << _.getIdName(return_type->id()) << "'s return type.";
    207   }
    208 
    209   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
    210   const auto function_type = _.FindDef(function_type_id);
    211   if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
    212     return _.diag(SPV_ERROR_INVALID_ID, inst)
    213            << "Missing function type definition.";
    214   }
    215 
    216   const auto function_call_arg_count = inst->words().size() - 4;
    217   const auto function_param_count = function_type->words().size() - 3;
    218   if (function_param_count != function_call_arg_count) {
    219     return _.diag(SPV_ERROR_INVALID_ID, inst)
    220            << "OpFunctionCall Function <id>'s parameter count does not match "
    221               "the argument count.";
    222   }
    223 
    224   for (size_t argument_index = 3, param_index = 2;
    225        argument_index < inst->operands().size();
    226        argument_index++, param_index++) {
    227     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
    228     const auto argument = _.FindDef(argument_id);
    229     if (!argument) {
    230       return _.diag(SPV_ERROR_INVALID_ID, inst)
    231              << "Missing argument " << argument_index - 3 << " definition.";
    232     }
    233 
    234     const auto argument_type = _.FindDef(argument->type_id());
    235     if (!argument_type) {
    236       return _.diag(SPV_ERROR_INVALID_ID, inst)
    237              << "Missing argument " << argument_index - 3
    238              << " type definition.";
    239     }
    240 
    241     const auto parameter_type_id =
    242         function_type->GetOperandAs<uint32_t>(param_index);
    243     const auto parameter_type = _.FindDef(parameter_type_id);
    244     if (!parameter_type || argument_type->id() != parameter_type->id()) {
    245       return _.diag(SPV_ERROR_INVALID_ID, inst)
    246              << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
    247              << "'s type does not match Function <id> '"
    248              << _.getIdName(parameter_type_id) << "'s parameter type.";
    249     }
    250   }
    251   return SPV_SUCCESS;
    252 }
    253 
    254 }  // namespace
    255 
    256 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
    257   switch (inst->opcode()) {
    258     case SpvOpFunction:
    259       if (auto error = ValidateFunction(_, inst)) return error;
    260       break;
    261     case SpvOpFunctionParameter:
    262       if (auto error = ValidateFunctionParameter(_, inst)) return error;
    263       break;
    264     case SpvOpFunctionCall:
    265       if (auto error = ValidateFunctionCall(_, inst)) return error;
    266       break;
    267     default:
    268       break;
    269   }
    270 
    271   return SPV_SUCCESS;
    272 }
    273 
    274 }  // namespace val
    275 }  // namespace spvtools
    276