1 // Copyright (c) 2018 Google LLC. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/val/validate.h" 16 17 #include "source/opcode.h" 18 #include "source/val/instruction.h" 19 #include "source/val/validation_state.h" 20 21 namespace spvtools { 22 namespace val { 23 namespace { 24 25 spv_result_t ValidateConstantBool(ValidationState_t& _, 26 const Instruction* inst) { 27 auto type = _.FindDef(inst->type_id()); 28 if (!type || type->opcode() != SpvOpTypeBool) { 29 return _.diag(SPV_ERROR_INVALID_ID, inst) 30 << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '" 31 << _.getIdName(inst->type_id()) << "' is not a boolean type."; 32 } 33 34 return SPV_SUCCESS; 35 } 36 37 spv_result_t ValidateConstantComposite(ValidationState_t& _, 38 const Instruction* inst) { 39 std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode()); 40 41 const auto result_type = _.FindDef(inst->type_id()); 42 if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) { 43 return _.diag(SPV_ERROR_INVALID_ID, inst) 44 << opcode_name << " Result Type <id> '" 45 << _.getIdName(inst->type_id()) << "' is not a composite type."; 46 } 47 48 const auto constituent_count = inst->words().size() - 3; 49 switch (result_type->opcode()) { 50 case SpvOpTypeVector: { 51 const auto component_count = result_type->GetOperandAs<uint32_t>(2); 52 if (component_count != constituent_count) { 53 // TODO: Output ID's on diagnostic 54 return _.diag(SPV_ERROR_INVALID_ID, inst) 55 << opcode_name 56 << " Constituent <id> count does not match " 57 "Result Type <id> '" 58 << _.getIdName(result_type->id()) 59 << "'s vector component count."; 60 } 61 const auto component_type = 62 _.FindDef(result_type->GetOperandAs<uint32_t>(1)); 63 if (!component_type) { 64 return _.diag(SPV_ERROR_INVALID_ID, result_type) 65 << "Component type is not defined."; 66 } 67 for (size_t constituent_index = 2; 68 constituent_index < inst->operands().size(); constituent_index++) { 69 const auto constituent_id = 70 inst->GetOperandAs<uint32_t>(constituent_index); 71 const auto constituent = _.FindDef(constituent_id); 72 if (!constituent || 73 !spvOpcodeIsConstantOrUndef(constituent->opcode())) { 74 return _.diag(SPV_ERROR_INVALID_ID, inst) 75 << opcode_name << " Constituent <id> '" 76 << _.getIdName(constituent_id) 77 << "' is not a constant or undef."; 78 } 79 const auto constituent_result_type = _.FindDef(constituent->type_id()); 80 if (!constituent_result_type || 81 component_type->opcode() != constituent_result_type->opcode()) { 82 return _.diag(SPV_ERROR_INVALID_ID, inst) 83 << opcode_name << " Constituent <id> '" 84 << _.getIdName(constituent_id) 85 << "'s type does not match Result Type <id> '" 86 << _.getIdName(result_type->id()) << "'s vector element type."; 87 } 88 } 89 } break; 90 case SpvOpTypeMatrix: { 91 const auto column_count = result_type->GetOperandAs<uint32_t>(2); 92 if (column_count != constituent_count) { 93 // TODO: Output ID's on diagnostic 94 return _.diag(SPV_ERROR_INVALID_ID, inst) 95 << opcode_name 96 << " Constituent <id> count does not match " 97 "Result Type <id> '" 98 << _.getIdName(result_type->id()) << "'s matrix column count."; 99 } 100 101 const auto column_type = _.FindDef(result_type->words()[2]); 102 if (!column_type) { 103 return _.diag(SPV_ERROR_INVALID_ID, result_type) 104 << "Column type is not defined."; 105 } 106 const auto component_count = column_type->GetOperandAs<uint32_t>(2); 107 const auto component_type = 108 _.FindDef(column_type->GetOperandAs<uint32_t>(1)); 109 if (!component_type) { 110 return _.diag(SPV_ERROR_INVALID_ID, column_type) 111 << "Component type is not defined."; 112 } 113 114 for (size_t constituent_index = 2; 115 constituent_index < inst->operands().size(); constituent_index++) { 116 const auto constituent_id = 117 inst->GetOperandAs<uint32_t>(constituent_index); 118 const auto constituent = _.FindDef(constituent_id); 119 if (!constituent || 120 !(SpvOpConstantComposite == constituent->opcode() || 121 SpvOpSpecConstantComposite == constituent->opcode() || 122 SpvOpUndef == constituent->opcode())) { 123 // The message says "... or undef" because the spec does not say 124 // undef is a constant. 125 return _.diag(SPV_ERROR_INVALID_ID, inst) 126 << opcode_name << " Constituent <id> '" 127 << _.getIdName(constituent_id) 128 << "' is not a constant composite or undef."; 129 } 130 const auto vector = _.FindDef(constituent->type_id()); 131 if (!vector) { 132 return _.diag(SPV_ERROR_INVALID_ID, constituent) 133 << "Result type is not defined."; 134 } 135 if (column_type->opcode() != vector->opcode()) { 136 return _.diag(SPV_ERROR_INVALID_ID, inst) 137 << opcode_name << " Constituent <id> '" 138 << _.getIdName(constituent_id) 139 << "' type does not match Result Type <id> '" 140 << _.getIdName(result_type->id()) << "'s matrix column type."; 141 } 142 const auto vector_component_type = 143 _.FindDef(vector->GetOperandAs<uint32_t>(1)); 144 if (component_type->id() != vector_component_type->id()) { 145 return _.diag(SPV_ERROR_INVALID_ID, inst) 146 << opcode_name << " Constituent <id> '" 147 << _.getIdName(constituent_id) 148 << "' component type does not match Result Type <id> '" 149 << _.getIdName(result_type->id()) 150 << "'s matrix column component type."; 151 } 152 if (component_count != vector->words()[3]) { 153 return _.diag(SPV_ERROR_INVALID_ID, inst) 154 << opcode_name << " Constituent <id> '" 155 << _.getIdName(constituent_id) 156 << "' vector component count does not match Result Type <id> '" 157 << _.getIdName(result_type->id()) 158 << "'s vector component count."; 159 } 160 } 161 } break; 162 case SpvOpTypeArray: { 163 auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1)); 164 if (!element_type) { 165 return _.diag(SPV_ERROR_INVALID_ID, result_type) 166 << "Element type is not defined."; 167 } 168 const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2)); 169 if (!length) { 170 return _.diag(SPV_ERROR_INVALID_ID, result_type) 171 << "Length is not defined."; 172 } 173 bool is_int32; 174 bool is_const; 175 uint32_t value; 176 std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id()); 177 if (is_int32 && is_const && value != constituent_count) { 178 return _.diag(SPV_ERROR_INVALID_ID, inst) 179 << opcode_name 180 << " Constituent count does not match " 181 "Result Type <id> '" 182 << _.getIdName(result_type->id()) << "'s array length."; 183 } 184 for (size_t constituent_index = 2; 185 constituent_index < inst->operands().size(); constituent_index++) { 186 const auto constituent_id = 187 inst->GetOperandAs<uint32_t>(constituent_index); 188 const auto constituent = _.FindDef(constituent_id); 189 if (!constituent || 190 !spvOpcodeIsConstantOrUndef(constituent->opcode())) { 191 return _.diag(SPV_ERROR_INVALID_ID, inst) 192 << opcode_name << " Constituent <id> '" 193 << _.getIdName(constituent_id) 194 << "' is not a constant or undef."; 195 } 196 const auto constituent_type = _.FindDef(constituent->type_id()); 197 if (!constituent_type) { 198 return _.diag(SPV_ERROR_INVALID_ID, constituent) 199 << "Result type is not defined."; 200 } 201 if (element_type->id() != constituent_type->id()) { 202 return _.diag(SPV_ERROR_INVALID_ID, inst) 203 << opcode_name << " Constituent <id> '" 204 << _.getIdName(constituent_id) 205 << "'s type does not match Result Type <id> '" 206 << _.getIdName(result_type->id()) << "'s array element type."; 207 } 208 } 209 } break; 210 case SpvOpTypeStruct: { 211 const auto member_count = result_type->words().size() - 2; 212 if (member_count != constituent_count) { 213 return _.diag(SPV_ERROR_INVALID_ID, inst) 214 << opcode_name << " Constituent <id> '" 215 << _.getIdName(inst->type_id()) 216 << "' count does not match Result Type <id> '" 217 << _.getIdName(result_type->id()) << "'s struct member count."; 218 } 219 for (uint32_t constituent_index = 2, member_index = 1; 220 constituent_index < inst->operands().size(); 221 constituent_index++, member_index++) { 222 const auto constituent_id = 223 inst->GetOperandAs<uint32_t>(constituent_index); 224 const auto constituent = _.FindDef(constituent_id); 225 if (!constituent || 226 !spvOpcodeIsConstantOrUndef(constituent->opcode())) { 227 return _.diag(SPV_ERROR_INVALID_ID, inst) 228 << opcode_name << " Constituent <id> '" 229 << _.getIdName(constituent_id) 230 << "' is not a constant or undef."; 231 } 232 const auto constituent_type = _.FindDef(constituent->type_id()); 233 if (!constituent_type) { 234 return _.diag(SPV_ERROR_INVALID_ID, constituent) 235 << "Result type is not defined."; 236 } 237 238 const auto member_type_id = 239 result_type->GetOperandAs<uint32_t>(member_index); 240 const auto member_type = _.FindDef(member_type_id); 241 if (!member_type || member_type->id() != constituent_type->id()) { 242 return _.diag(SPV_ERROR_INVALID_ID, inst) 243 << opcode_name << " Constituent <id> '" 244 << _.getIdName(constituent_id) 245 << "' type does not match the Result Type <id> '" 246 << _.getIdName(result_type->id()) << "'s member type."; 247 } 248 } 249 } break; 250 default: 251 break; 252 } 253 return SPV_SUCCESS; 254 } 255 256 spv_result_t ValidateConstantSampler(ValidationState_t& _, 257 const Instruction* inst) { 258 const auto result_type = _.FindDef(inst->type_id()); 259 if (!result_type || result_type->opcode() != SpvOpTypeSampler) { 260 return _.diag(SPV_ERROR_INVALID_ID, result_type) 261 << "OpConstantSampler Result Type <id> '" 262 << _.getIdName(inst->type_id()) << "' is not a sampler type."; 263 } 264 265 return SPV_SUCCESS; 266 } 267 268 // True if instruction defines a type that can have a null value, as defined by 269 // the SPIR-V spec. Tracks composite-type components through module to check 270 // nullability transitively. 271 bool IsTypeNullable(const std::vector<uint32_t>& instruction, 272 const ValidationState_t& _) { 273 uint16_t opcode; 274 uint16_t word_count; 275 spvOpcodeSplit(instruction[0], &word_count, &opcode); 276 switch (static_cast<SpvOp>(opcode)) { 277 case SpvOpTypeBool: 278 case SpvOpTypeInt: 279 case SpvOpTypeFloat: 280 case SpvOpTypePointer: 281 case SpvOpTypeEvent: 282 case SpvOpTypeDeviceEvent: 283 case SpvOpTypeReserveId: 284 case SpvOpTypeQueue: 285 return true; 286 case SpvOpTypeArray: 287 case SpvOpTypeMatrix: 288 case SpvOpTypeVector: { 289 auto base_type = _.FindDef(instruction[2]); 290 return base_type && IsTypeNullable(base_type->words(), _); 291 } 292 case SpvOpTypeStruct: { 293 for (size_t elementIndex = 2; elementIndex < instruction.size(); 294 ++elementIndex) { 295 auto element = _.FindDef(instruction[elementIndex]); 296 if (!element || !IsTypeNullable(element->words(), _)) return false; 297 } 298 return true; 299 } 300 default: 301 return false; 302 } 303 } 304 305 spv_result_t ValidateConstantNull(ValidationState_t& _, 306 const Instruction* inst) { 307 const auto result_type = _.FindDef(inst->type_id()); 308 if (!result_type || !IsTypeNullable(result_type->words(), _)) { 309 return _.diag(SPV_ERROR_INVALID_ID, inst) 310 << "OpConstantNull Result Type <id> '" 311 << _.getIdName(inst->type_id()) << "' cannot have a null value."; 312 } 313 314 return SPV_SUCCESS; 315 } 316 317 } // namespace 318 319 spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) { 320 switch (inst->opcode()) { 321 case SpvOpConstantTrue: 322 case SpvOpConstantFalse: 323 case SpvOpSpecConstantTrue: 324 case SpvOpSpecConstantFalse: 325 if (auto error = ValidateConstantBool(_, inst)) return error; 326 break; 327 case SpvOpConstantComposite: 328 case SpvOpSpecConstantComposite: 329 if (auto error = ValidateConstantComposite(_, inst)) return error; 330 break; 331 case SpvOpConstantSampler: 332 if (auto error = ValidateConstantSampler(_, inst)) return error; 333 break; 334 case SpvOpConstantNull: 335 if (auto error = ValidateConstantNull(_, inst)) return error; 336 break; 337 default: 338 break; 339 } 340 341 return SPV_SUCCESS; 342 } 343 344 } // namespace val 345 } // namespace spvtools 346