1 // Copyright (c) 2018 Google LLC. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/val/validate.h" 16 17 #include <algorithm> 18 19 #include "source/opcode.h" 20 #include "source/val/instruction.h" 21 #include "source/val/validation_state.h" 22 23 namespace spvtools { 24 namespace val { 25 namespace { 26 27 spv_result_t ValidateFunction(ValidationState_t& _, const Instruction* inst) { 28 const auto function_type_id = inst->GetOperandAs<uint32_t>(3); 29 const auto function_type = _.FindDef(function_type_id); 30 if (!function_type || SpvOpTypeFunction != function_type->opcode()) { 31 return _.diag(SPV_ERROR_INVALID_ID, inst) 32 << "OpFunction Function Type <id> '" << _.getIdName(function_type_id) 33 << "' is not a function type."; 34 } 35 36 const auto return_id = function_type->GetOperandAs<uint32_t>(1); 37 if (return_id != inst->type_id()) { 38 return _.diag(SPV_ERROR_INVALID_ID, inst) 39 << "OpFunction Result Type <id> '" << _.getIdName(inst->type_id()) 40 << "' does not match the Function Type's return type <id> '" 41 << _.getIdName(return_id) << "'."; 42 } 43 44 for (auto& pair : inst->uses()) { 45 const auto* use = pair.first; 46 const std::vector<SpvOp> acceptable = { 47 SpvOpFunctionCall, 48 SpvOpEntryPoint, 49 SpvOpEnqueueKernel, 50 SpvOpGetKernelNDrangeSubGroupCount, 51 SpvOpGetKernelNDrangeMaxSubGroupSize, 52 SpvOpGetKernelWorkGroupSize, 53 SpvOpGetKernelPreferredWorkGroupSizeMultiple, 54 SpvOpGetKernelLocalSizeForSubgroupCount, 55 SpvOpGetKernelMaxNumSubgroups}; 56 if (std::find(acceptable.begin(), acceptable.end(), use->opcode()) == 57 acceptable.end()) { 58 return _.diag(SPV_ERROR_INVALID_ID, use) 59 << "Invalid use of function result id " << _.getIdName(inst->id()) 60 << "."; 61 } 62 } 63 64 return SPV_SUCCESS; 65 } 66 67 spv_result_t ValidateFunctionParameter(ValidationState_t& _, 68 const Instruction* inst) { 69 // NOTE: Find OpFunction & ensure OpFunctionParameter is not out of place. 70 size_t param_index = 0; 71 size_t inst_num = inst->LineNum() - 1; 72 if (inst_num == 0) { 73 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 74 << "Function parameter cannot be the first instruction."; 75 } 76 77 auto func_inst = &_.ordered_instructions()[inst_num]; 78 while (--inst_num) { 79 func_inst = &_.ordered_instructions()[inst_num]; 80 if (func_inst->opcode() == SpvOpFunction) { 81 break; 82 } else if (func_inst->opcode() == SpvOpFunctionParameter) { 83 ++param_index; 84 } 85 } 86 87 if (func_inst->opcode() != SpvOpFunction) { 88 return _.diag(SPV_ERROR_INVALID_LAYOUT, inst) 89 << "Function parameter must be preceded by a function."; 90 } 91 92 const auto function_type_id = func_inst->GetOperandAs<uint32_t>(3); 93 const auto function_type = _.FindDef(function_type_id); 94 if (!function_type) { 95 return _.diag(SPV_ERROR_INVALID_ID, func_inst) 96 << "Missing function type definition."; 97 } 98 if (param_index >= function_type->words().size() - 3) { 99 return _.diag(SPV_ERROR_INVALID_ID, inst) 100 << "Too many OpFunctionParameters for " << func_inst->id() 101 << ": expected " << function_type->words().size() - 3 102 << " based on the function's type"; 103 } 104 105 const auto param_type = 106 _.FindDef(function_type->GetOperandAs<uint32_t>(param_index + 2)); 107 if (!param_type || inst->type_id() != param_type->id()) { 108 return _.diag(SPV_ERROR_INVALID_ID, inst) 109 << "OpFunctionParameter Result Type <id> '" 110 << _.getIdName(inst->type_id()) 111 << "' does not match the OpTypeFunction parameter " 112 "type of the same index."; 113 } 114 115 // Validate that PhysicalStorageBufferEXT have one of Restrict, Aliased, 116 // RestrictPointerEXT, or AliasedPointerEXT. 117 auto param_nonarray_type_id = param_type->id(); 118 while (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypeArray) { 119 param_nonarray_type_id = 120 _.FindDef(param_nonarray_type_id)->GetOperandAs<uint32_t>(1u); 121 } 122 if (_.GetIdOpcode(param_nonarray_type_id) == SpvOpTypePointer) { 123 auto param_nonarray_type = _.FindDef(param_nonarray_type_id); 124 if (param_nonarray_type->GetOperandAs<uint32_t>(1u) == 125 SpvStorageClassPhysicalStorageBufferEXT) { 126 // check for Aliased or Restrict 127 const auto& decorations = _.id_decorations(inst->id()); 128 129 bool foundAliased = std::any_of( 130 decorations.begin(), decorations.end(), [](const Decoration& d) { 131 return SpvDecorationAliased == d.dec_type(); 132 }); 133 134 bool foundRestrict = std::any_of( 135 decorations.begin(), decorations.end(), [](const Decoration& d) { 136 return SpvDecorationRestrict == d.dec_type(); 137 }); 138 139 if (!foundAliased && !foundRestrict) { 140 return _.diag(SPV_ERROR_INVALID_ID, inst) 141 << "OpFunctionParameter " << inst->id() 142 << ": expected Aliased or Restrict for PhysicalStorageBufferEXT " 143 "pointer."; 144 } 145 if (foundAliased && foundRestrict) { 146 return _.diag(SPV_ERROR_INVALID_ID, inst) 147 << "OpFunctionParameter " << inst->id() 148 << ": can't specify both Aliased and Restrict for " 149 "PhysicalStorageBufferEXT pointer."; 150 } 151 } else { 152 const auto pointee_type_id = 153 param_nonarray_type->GetOperandAs<uint32_t>(2); 154 const auto pointee_type = _.FindDef(pointee_type_id); 155 if (SpvOpTypePointer == pointee_type->opcode() && 156 pointee_type->GetOperandAs<uint32_t>(1u) == 157 SpvStorageClassPhysicalStorageBufferEXT) { 158 // check for AliasedPointerEXT/RestrictPointerEXT 159 const auto& decorations = _.id_decorations(inst->id()); 160 161 bool foundAliased = std::any_of( 162 decorations.begin(), decorations.end(), [](const Decoration& d) { 163 return SpvDecorationAliasedPointerEXT == d.dec_type(); 164 }); 165 166 bool foundRestrict = std::any_of( 167 decorations.begin(), decorations.end(), [](const Decoration& d) { 168 return SpvDecorationRestrictPointerEXT == d.dec_type(); 169 }); 170 171 if (!foundAliased && !foundRestrict) { 172 return _.diag(SPV_ERROR_INVALID_ID, inst) 173 << "OpFunctionParameter " << inst->id() 174 << ": expected AliasedPointerEXT or RestrictPointerEXT for " 175 "PhysicalStorageBufferEXT pointer."; 176 } 177 if (foundAliased && foundRestrict) { 178 return _.diag(SPV_ERROR_INVALID_ID, inst) 179 << "OpFunctionParameter " << inst->id() 180 << ": can't specify both AliasedPointerEXT and " 181 "RestrictPointerEXT for PhysicalStorageBufferEXT pointer."; 182 } 183 } 184 } 185 } 186 187 return SPV_SUCCESS; 188 } 189 190 spv_result_t ValidateFunctionCall(ValidationState_t& _, 191 const Instruction* inst) { 192 const auto function_id = inst->GetOperandAs<uint32_t>(2); 193 const auto function = _.FindDef(function_id); 194 if (!function || SpvOpFunction != function->opcode()) { 195 return _.diag(SPV_ERROR_INVALID_ID, inst) 196 << "OpFunctionCall Function <id> '" << _.getIdName(function_id) 197 << "' is not a function."; 198 } 199 200 auto return_type = _.FindDef(function->type_id()); 201 if (!return_type || return_type->id() != inst->type_id()) { 202 return _.diag(SPV_ERROR_INVALID_ID, inst) 203 << "OpFunctionCall Result Type <id> '" 204 << _.getIdName(inst->type_id()) 205 << "'s type does not match Function <id> '" 206 << _.getIdName(return_type->id()) << "'s return type."; 207 } 208 209 const auto function_type_id = function->GetOperandAs<uint32_t>(3); 210 const auto function_type = _.FindDef(function_type_id); 211 if (!function_type || function_type->opcode() != SpvOpTypeFunction) { 212 return _.diag(SPV_ERROR_INVALID_ID, inst) 213 << "Missing function type definition."; 214 } 215 216 const auto function_call_arg_count = inst->words().size() - 4; 217 const auto function_param_count = function_type->words().size() - 3; 218 if (function_param_count != function_call_arg_count) { 219 return _.diag(SPV_ERROR_INVALID_ID, inst) 220 << "OpFunctionCall Function <id>'s parameter count does not match " 221 "the argument count."; 222 } 223 224 for (size_t argument_index = 3, param_index = 2; 225 argument_index < inst->operands().size(); 226 argument_index++, param_index++) { 227 const auto argument_id = inst->GetOperandAs<uint32_t>(argument_index); 228 const auto argument = _.FindDef(argument_id); 229 if (!argument) { 230 return _.diag(SPV_ERROR_INVALID_ID, inst) 231 << "Missing argument " << argument_index - 3 << " definition."; 232 } 233 234 const auto argument_type = _.FindDef(argument->type_id()); 235 if (!argument_type) { 236 return _.diag(SPV_ERROR_INVALID_ID, inst) 237 << "Missing argument " << argument_index - 3 238 << " type definition."; 239 } 240 241 const auto parameter_type_id = 242 function_type->GetOperandAs<uint32_t>(param_index); 243 const auto parameter_type = _.FindDef(parameter_type_id); 244 if (!parameter_type || argument_type->id() != parameter_type->id()) { 245 return _.diag(SPV_ERROR_INVALID_ID, inst) 246 << "OpFunctionCall Argument <id> '" << _.getIdName(argument_id) 247 << "'s type does not match Function <id> '" 248 << _.getIdName(parameter_type_id) << "'s parameter type."; 249 } 250 } 251 return SPV_SUCCESS; 252 } 253 254 } // namespace 255 256 spv_result_t FunctionPass(ValidationState_t& _, const Instruction* inst) { 257 switch (inst->opcode()) { 258 case SpvOpFunction: 259 if (auto error = ValidateFunction(_, inst)) return error; 260 break; 261 case SpvOpFunctionParameter: 262 if (auto error = ValidateFunctionParameter(_, inst)) return error; 263 break; 264 case SpvOpFunctionCall: 265 if (auto error = ValidateFunctionCall(_, inst)) return error; 266 break; 267 default: 268 break; 269 } 270 271 return SPV_SUCCESS; 272 } 273 274 } // namespace val 275 } // namespace spvtools 276