Home | History | Annotate | Download | only in val
      1 // Copyright (c) 2018 Google LLC.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/val/validate.h"
     16 
     17 #include <algorithm>
     18 
     19 #include "source/opcode.h"
     20 #include "source/val/instruction.h"
     21 #include "source/val/validation_state.h"
     22 
     23 namespace spvtools {
     24 namespace val {
     25 namespace {
     26 
     27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) {
     28   const auto function_type_id = inst->GetOperandAs<uint32_t>(3);
     29   const auto function_type = _.FindDef(function_type_id);
     30   if (!function_type || SpvOpTypeFunction != function_type->opcode()) {
     31     return _.diag(SPV_ERROR_INVALID_ID, inst)
     32            << "OpFunction Function Type <id> '" << _.getIdName(function_type_id)
     33            << "' is not a function type.";
     34   }
     35 
     36   const auto return_id = function_type->GetOperandAs<uint32_t>(1);
     37   if (return_id != inst->type_id()) {
     38     return _.diag(SPV_ERROR_INVALID_ID, inst)
     39            << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id())
     40            << "' does not match the Function Type's return type <id> '"
     41            << _.getIdName(return_id) << "'.";
     42   }
     43 
     44   for (auto& pair : inst->uses()) {
     45     const auto* use = pair.first;
     46     const std::vector<SpvOp> acceptable = {
     47         SpvOpFunctionCall,
     48         SpvOpEntryPoint,
     49         SpvOpEnqueueKernel,
     50         SpvOpGetKernelNDrangeSubGroupCount,
     51         SpvOpGetKernelNDrangeMaxSubGroupSize,
     52         SpvOpGetKernelWorkGroupSize,
     53         SpvOpGetKernelPreferredWorkGroupSizeMultiple,
     54         SpvOpGetKernelLocalSizeForSubgroupCount,
     55         SpvOpGetKernelMaxNumSubgroups};
     56     if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) ==
     57         acceptable.end()) {
     58       return _.diag(SPV_ERROR_INVALID_ID, use)
     59              << "Invalid use of function result id " << _.getIdName(inst->id())
     60              << ".";
     61     }
     62   }
     63 
     64   return SPV_SUCCESS;
     65 }
     66 
     67 spv_result_t ValidateFunctionParameter(ValidationState_t& _,
     68                                        const Instruction* inst) {
     69   // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place.
     70   size_t param_index = 0;
     71   size_t inst_num = inst->LineNum() - 1;
     72   if (inst_num == 0) {
     73     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
     74            << "Function parameter cannot be the first instruction.";
     75   }
     76 
     77   auto func_inst = &_.ordered_instructions()[inst_num];
     78   while (--inst_num) {
     79     func_inst = &_.ordered_instructions()[inst_num];
     80     if (func_inst->opcode() == SpvOpFunction) {
     81       break;
     82     } else if (func_inst->opcode() == SpvOpFunctionParameter) {
     83       ++param_index;
     84     }
     85   }
     86 
     87   if (func_inst->opcode() != SpvOpFunction) {
     88     return _.diag(SPV_ERROR_INVALID_LAYOUT, inst)
     89            << "Function parameter must be preceded by a function.";
     90   }
     91 
     92   const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3);
     93   const auto function_type = _.FindDef(function_type_id);
     94   if (!function_type) {
     95     return _.diag(SPV_ERROR_INVALID_ID, func_inst)
     96            << "Missing function type definition.";
     97   }
     98   if (param_index >= function_type->words().size() - 3) {
     99     return _.diag(SPV_ERROR_INVALID_ID, inst)
    100            << "Too many OpFunctionParameters for " << func_inst->id()
    101            << ": expected " << function_type->words().size() - 3
    102            << " based on the function's type";
    103   }
    104 
    105   const auto param_type =
    106       _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2));
    107   if (!param_type || inst->type_id() != param_type->id()) {
    108     return _.diag(SPV_ERROR_INVALID_ID, inst)
    109            << "OpFunctionParameter Result Type <id> '"
    110            << _.getIdName(inst->type_id())
    111            << "' does not match the OpTypeFunction parameter "
    112               "type of the same index.";
    113   }
    114   return SPV_SUCCESS;
    115 }
    116 
    117 spv_result_t ValidateFunctionCall(ValidationState_t& _,
    118                                   const Instruction* inst) {
    119   const auto function_id = inst->GetOperandAs<uint32_t>(2);
    120   const auto function = _.FindDef(function_id);
    121   if (!function || SpvOpFunction != function->opcode()) {
    122     return _.diag(SPV_ERROR_INVALID_ID, inst)
    123            << "OpFunctionCall Function <id> '" << _.getIdName(function_id)
    124            << "' is not a function.";
    125   }
    126 
    127   auto return_type = _.FindDef(function->type_id());
    128   if (!return_type || return_type->id() != inst->type_id()) {
    129     return _.diag(SPV_ERROR_INVALID_ID, inst)
    130            << "OpFunctionCall Result Type <id> '"
    131            << _.getIdName(inst->type_id())
    132            << "'s type does not match Function <id> '"
    133            << _.getIdName(return_type->id()) << "'s return type.";
    134   }
    135 
    136   const auto function_type_id = function->GetOperandAs<uint32_t>(3);
    137   const auto function_type = _.FindDef(function_type_id);
    138   if (!function_type || function_type->opcode() != SpvOpTypeFunction) {
    139     return _.diag(SPV_ERROR_INVALID_ID, inst)
    140            << "Missing function type definition.";
    141   }
    142 
    143   const auto function_call_arg_count = inst->words().size() - 4;
    144   const auto function_param_count = function_type->words().size() - 3;
    145   if (function_param_count != function_call_arg_count) {
    146     return _.diag(SPV_ERROR_INVALID_ID, inst)
    147            << "OpFunctionCall Function <id>'s parameter count does not match "
    148               "the argument count.";
    149   }
    150 
    151   for (size_t argument_index = 3, param_index = 2;
    152        argument_index < inst->operands().size();
    153        argument_index++, param_index++) {
    154     const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index);
    155     const auto argument = _.FindDef(argument_id);
    156     if (!argument) {
    157       return _.diag(SPV_ERROR_INVALID_ID, inst)
    158              << "Missing argument " << argument_index - 3 << " definition.";
    159     }
    160 
    161     const auto argument_type = _.FindDef(argument->type_id());
    162     if (!argument_type) {
    163       return _.diag(SPV_ERROR_INVALID_ID, inst)
    164              << "Missing argument " << argument_index - 3
    165              << " type definition.";
    166     }
    167 
    168     const auto parameter_type_id =
    169         function_type->GetOperandAs<uint32_t>(param_index);
    170     const auto parameter_type = _.FindDef(parameter_type_id);
    171     if (!parameter_type || argument_type->id() != parameter_type->id()) {
    172       return _.diag(SPV_ERROR_INVALID_ID, inst)
    173              << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id)
    174              << "'s type does not match Function <id> '"
    175              << _.getIdName(parameter_type_id) << "'s parameter type.";
    176     }
    177   }
    178   return SPV_SUCCESS;
    179 }
    180 
    181 }  // namespace
    182 
    183 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) {
    184   switch (inst->opcode()) {
    185     case SpvOpFunction:
    186       if (auto error = ValidateFunction(_, inst)) return error;
    187       break;
    188     case SpvOpFunctionParameter:
    189       if (auto error = ValidateFunctionParameter(_, inst)) return error;
    190       break;
    191     case SpvOpFunctionCall:
    192       if (auto error = ValidateFunctionCall(_, inst)) return error;
    193       break;
    194     default:
    195       break;
    196   }
    197 
    198   return SPV_SUCCESS;
    199 }
    200 
    201 }  // namespace val
    202 }  // namespace spvtools
    203