1 // Copyright (c) 2018 Google LLC. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/val/validate.h" 16 17 #include <algorithm> 18 19 #include "source/opcode.h" 20 #include "source/val/instruction.h" 21 #include "source/val/validation_state.h" 22 23 namespace spvtools { 24 namespace val { 25 namespace { 26 27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { 28 const auto function_type_id = inst->GetOperandAs<uint32_t>(3); 29 const auto function_type = _.FindDef(function_type_id); 30 if (!function_type || SpvOpTypeFunction != function_type->opcode()) { 31 return _.diag(SPV_ERROR_INVALID_ID, inst) 32 << "OpFunction Function Type <id> '" << _.getIdName(function_type_id) 33 << "' is not a function type."; 34 } 35 36 const auto return_id = function_type->GetOperandAs<uint32_t>(1); 37 if (return_id != inst->type_id()) { 38 return _.diag(SPV_ERROR_INVALID_ID, inst) 39 << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id()) 40 << "' does not match the Function Type's return type <id> '" 41 << _.getIdName(return_id) << "'."; 42 } 43 44 for (auto& pair : inst->uses()) { 45 const auto* use = pair.first; 46 const std::vector<SpvOp> acceptable = { 47 SpvOpFunctionCall, 48 SpvOpEntryPoint, 49 SpvOpEnqueueKernel, 50 SpvOpGetKernelNDrangeSubGroupCount, 51 SpvOpGetKernelNDrangeMaxSubGroupSize, 52 SpvOpGetKernelWorkGroupSize, 53 SpvOpGetKernelPreferredWorkGroupSizeMultiple, 54 SpvOpGetKernelLocalSizeForSubgroupCount, 55 SpvOpGetKernelMaxNumSubgroups}; 56 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == 57 acceptable.end()) { 58 return _.diag(SPV_ERROR_INVALID_ID, use) 59 << "Invalid use of function result id " << _.getIdName(inst->id()) 60 << "."; 61 } 62 } 63 64 return SPV_SUCCESS; 65 } 66 67 spv_result_t ValidateFunctionParameter(ValidationState_t& _, 68 const Instruction* inst) { 69 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. 70 size_t param_index = 0; 71 size_t inst_num = inst->LineNum() - 1; 72 if (inst_num == 0) { 73 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 74 << "Function parameter cannot be the first instruction."; 75 } 76 77 auto func_inst = &_.ordered_instructions()[inst_num]; 78 while (--inst_num) { 79 func_inst = &_.ordered_instructions()[inst_num]; 80 if (func_inst->opcode() == SpvOpFunction) { 81 break; 82 } else if (func_inst->opcode() == SpvOpFunctionParameter) { 83 ++param_index; 84 } 85 } 86 87 if (func_inst->opcode() != SpvOpFunction) { 88 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 89 << "Function parameter must be preceded by a function."; 90 } 91 92 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); 93 const auto function_type = _.FindDef(function_type_id); 94 if (!function_type) { 95 return _.diag(SPV_ERROR_INVALID_ID, func_inst) 96 << "Missing function type definition."; 97 } 98 if (param_index >= function_type->words().size() - 3) { 99 return _.diag(SPV_ERROR_INVALID_ID, inst) 100 << "Too many OpFunctionParameters for " << func_inst->id() 101 << ": expected " << function_type->words().size() - 3 102 << " based on the function's type"; 103 } 104 105 const auto param_type = 106 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); 107 if (!param_type || inst->type_id() != param_type->id()) { 108 return _.diag(SPV_ERROR_INVALID_ID, inst) 109 << "OpFunctionParameter Result Type <id> '" 110 << _.getIdName(inst->type_id()) 111 << "' does not match the OpTypeFunction parameter " 112 "type of the same index."; 113 } 114 return SPV_SUCCESS; 115 } 116 117 spv_result_t ValidateFunctionCall(ValidationState_t& _, 118 const Instruction* inst) { 119 const auto function_id = inst->GetOperandAs<uint32_t>(2); 120 const auto function = _.FindDef(function_id); 121 if (!function || SpvOpFunction != function->opcode()) { 122 return _.diag(SPV_ERROR_INVALID_ID, inst) 123 << "OpFunctionCall Function <id> '" << _.getIdName(function_id) 124 << "' is not a function."; 125 } 126 127 auto return_type = _.FindDef(function->type_id()); 128 if (!return_type || return_type->id() != inst->type_id()) { 129 return _.diag(SPV_ERROR_INVALID_ID, inst) 130 << "OpFunctionCall Result Type <id> '" 131 << _.getIdName(inst->type_id()) 132 << "'s type does not match Function <id> '" 133 << _.getIdName(return_type->id()) << "'s return type."; 134 } 135 136 const auto function_type_id = function->GetOperandAs<uint32_t>(3); 137 const auto function_type = _.FindDef(function_type_id); 138 if (!function_type || function_type->opcode() != SpvOpTypeFunction) { 139 return _.diag(SPV_ERROR_INVALID_ID, inst) 140 << "Missing function type definition."; 141 } 142 143 const auto function_call_arg_count = inst->words().size() - 4; 144 const auto function_param_count = function_type->words().size() - 3; 145 if (function_param_count != function_call_arg_count) { 146 return _.diag(SPV_ERROR_INVALID_ID, inst) 147 << "OpFunctionCall Function <id>'s parameter count does not match " 148 "the argument count."; 149 } 150 151 for (size_t argument_index = 3, param_index = 2; 152 argument_index < inst->operands().size(); 153 argument_index++, param_index++) { 154 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); 155 const auto argument = _.FindDef(argument_id); 156 if (!argument) { 157 return _.diag(SPV_ERROR_INVALID_ID, inst) 158 << "Missing argument " << argument_index - 3 << " definition."; 159 } 160 161 const auto argument_type = _.FindDef(argument->type_id()); 162 if (!argument_type) { 163 return _.diag(SPV_ERROR_INVALID_ID, inst) 164 << "Missing argument " << argument_index - 3 165 << " type definition."; 166 } 167 168 const auto parameter_type_id = 169 function_type->GetOperandAs<uint32_t>(param_index); 170 const auto parameter_type = _.FindDef(parameter_type_id); 171 if (!parameter_type || argument_type->id() != parameter_type->id()) { 172 return _.diag(SPV_ERROR_INVALID_ID, inst) 173 << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id) 174 << "'s type does not match Function <id> '" 175 << _.getIdName(parameter_type_id) << "'s parameter type."; 176 } 177 } 178 return SPV_SUCCESS; 179 } 180 181 } // namespace 182 183 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { 184 switch (inst->opcode()) { 185 case SpvOpFunction: 186 if (auto error = ValidateFunction(_, inst)) return error; 187 break; 188 case SpvOpFunctionParameter: 189 if (auto error = ValidateFunctionParameter(_, inst)) return error; 190 break; 191 case SpvOpFunctionCall: 192 if (auto error = ValidateFunctionCall(_, inst)) return error; 193 break; 194 default: 195 break; 196 } 197 198 return SPV_SUCCESS; 199 } 200 201 } // namespace val 202 } // namespace spvtools 203