1 // Copyright (c) 2015-2016 The Khronos Group Inc. 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "validate.h" 16 17 #include <cassert> 18 #include <cstdio> 19 20 #include <algorithm> 21 #include <functional> 22 #include <iterator> 23 #include <memory> 24 #include <sstream> 25 #include <string> 26 #include <vector> 27 28 #include "binary.h" 29 #include "diagnostic.h" 30 #include "enum_string_mapping.h" 31 #include "extensions.h" 32 #include "instruction.h" 33 #include "opcode.h" 34 #include "operand.h" 35 #include "spirv-tools/libspirv.h" 36 #include "spirv_constant.h" 37 #include "spirv_endian.h" 38 #include "spirv_validator_options.h" 39 #include "val/construct.h" 40 #include "val/function.h" 41 #include "val/validation_state.h" 42 43 using std::function; 44 using std::ostream_iterator; 45 using std::placeholders::_1; 46 using std::string; 47 using std::stringstream; 48 using std::transform; 49 using std::vector; 50 51 using libspirv::CfgPass; 52 using libspirv::Extension; 53 using libspirv::InstructionPass; 54 using libspirv::ModuleLayoutPass; 55 using libspirv::DataRulesPass; 56 using libspirv::IdPass; 57 using libspirv::ValidationState_t; 58 59 spv_result_t spvValidateIDs(const spv_instruction_t* pInsts, 60 const uint64_t count, 61 const spv_opcode_table opcodeTable, 62 const spv_operand_table operandTable, 63 const spv_ext_inst_table extInstTable, 64 const ValidationState_t& state, 65 spv_position position) { 66 position->index = SPV_INDEX_INSTRUCTION; 67 if (auto error = 68 spvValidateInstructionIDs(pInsts, count, opcodeTable, operandTable, 69 extInstTable, state, position)) 70 return error; 71 return SPV_SUCCESS; 72 } 73 74 namespace { 75 76 // TODO(umar): Validate header 77 // TODO(umar): The binary parser validates the magic word, and the length of the 78 // header, but nothing else. 79 spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic, 80 uint32_t version, uint32_t generator, uint32_t id_bound, 81 uint32_t reserved) { 82 // Record the ID bound so that the validator can ensure no ID is out of bound. 83 ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data)); 84 _.setIdBound(id_bound); 85 86 (void)endian; 87 (void)magic; 88 (void)version; 89 (void)generator; 90 (void)id_bound; 91 (void)reserved; 92 return SPV_SUCCESS; 93 } 94 95 // Improves diagnostic messages by collecting names of IDs 96 // NOTE: This function returns void and is not involved in validation 97 void DebugInstructionPass(ValidationState_t& _, 98 const spv_parsed_instruction_t* inst) { 99 switch (inst->opcode) { 100 case SpvOpName: { 101 const uint32_t target = *(inst->words + inst->operands[0].offset); 102 const char* str = 103 reinterpret_cast<const char*>(inst->words + inst->operands[1].offset); 104 _.AssignNameToId(target, str); 105 } break; 106 case SpvOpMemberName: { 107 const uint32_t target = *(inst->words + inst->operands[0].offset); 108 const char* str = 109 reinterpret_cast<const char*>(inst->words + inst->operands[2].offset); 110 _.AssignNameToId(target, str); 111 } break; 112 case SpvOpSourceContinued: 113 case SpvOpSource: 114 case SpvOpSourceExtension: 115 case SpvOpString: 116 case SpvOpLine: 117 case SpvOpNoLine: 118 119 default: 120 break; 121 } 122 } 123 124 // Parses OpExtension instruction and registers extension. 125 void RegisterExtension(ValidationState_t& _, 126 const spv_parsed_instruction_t* inst) { 127 const std::string extension_str = libspirv::GetExtensionString(inst); 128 Extension extension; 129 if (!GetExtensionFromString(extension_str, &extension)) { 130 // The error will be logged in the ProcessInstruction pass. 131 return; 132 } 133 134 _.RegisterExtension(extension); 135 } 136 137 // Parses the beginning of the module searching for OpExtension instructions. 138 // Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION 139 // once an instruction which is not SpvOpCapability and SpvOpExtension is 140 // encountered. According to the SPIR-V spec extensions are declared after 141 // capabilities and before everything else. 142 spv_result_t ProcessExtensions( 143 void* user_data, const spv_parsed_instruction_t* inst) { 144 const SpvOp opcode = static_cast<SpvOp>(inst->opcode); 145 if (opcode == SpvOpCapability) 146 return SPV_SUCCESS; 147 148 if (opcode == SpvOpExtension) { 149 ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data)); 150 RegisterExtension(_, inst); 151 return SPV_SUCCESS; 152 } 153 154 // OpExtension block is finished, requesting termination. 155 return SPV_REQUESTED_TERMINATION; 156 } 157 158 spv_result_t ProcessInstruction(void* user_data, 159 const spv_parsed_instruction_t* inst) { 160 ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data)); 161 _.increment_instruction_count(); 162 if (static_cast<SpvOp>(inst->opcode) == SpvOpEntryPoint) { 163 const auto entry_point = inst->words[2]; 164 _.RegisterEntryPointId(entry_point); 165 // Operand 3 and later are the <id> of interfaces for the entry point. 166 for (int i = 3; i < inst->num_operands; ++i) { 167 _.RegisterInterfaceForEntryPoint(entry_point, 168 inst->words[inst->operands[i].offset]); 169 } 170 } 171 if (static_cast<SpvOp>(inst->opcode) == SpvOpFunctionCall) { 172 _.AddFunctionCallTarget(inst->words[3]); 173 } 174 175 DebugInstructionPass(_, inst); 176 if (auto error = CapabilityPass(_, inst)) return error; 177 if (auto error = DataRulesPass(_, inst)) return error; 178 if (auto error = IdPass(_, inst)) return error; 179 if (auto error = ModuleLayoutPass(_, inst)) return error; 180 if (auto error = CfgPass(_, inst)) return error; 181 if (auto error = InstructionPass(_, inst)) return error; 182 if (auto error = TypeUniquePass(_, inst)) return error; 183 184 return SPV_SUCCESS; 185 } 186 187 void printDot(const ValidationState_t& _, const libspirv::BasicBlock& other) { 188 string block_string; 189 if (other.successors()->empty()) { 190 block_string += "end "; 191 } else { 192 for (auto block : *other.successors()) { 193 block_string += _.getIdOrName(block->id()) + " "; 194 } 195 } 196 printf("%10s -> {%s\b}\n", _.getIdOrName(other.id()).c_str(), 197 block_string.c_str()); 198 } 199 200 void PrintBlocks(ValidationState_t& _, libspirv::Function func) { 201 assert(func.first_block()); 202 203 printf("%10s -> %s\n", _.getIdOrName(func.id()).c_str(), 204 _.getIdOrName(func.first_block()->id()).c_str()); 205 for (const auto& block : func.ordered_blocks()) { 206 printDot(_, *block); 207 } 208 } 209 210 #ifdef __clang__ 211 #define UNUSED(func) [[gnu::unused]] func 212 #elif defined(__GNUC__) 213 #define UNUSED(func) \ 214 func __attribute__((unused)); \ 215 func 216 #elif defined(_MSC_VER) 217 #define UNUSED(func) func 218 #endif 219 220 UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) { 221 if (func.first_block()) { 222 string func_name(_.getIdOrName(func.id())); 223 printf("digraph %s {\n", func_name.c_str()); 224 PrintBlocks(_, func); 225 printf("}\n"); 226 } 227 } 228 } // anonymous namespace 229 230 spv_result_t spvValidate(const spv_const_context context, 231 const spv_const_binary binary, 232 spv_diagnostic* pDiagnostic) { 233 return spvValidateBinary(context, binary->code, binary->wordCount, 234 pDiagnostic); 235 } 236 237 spv_result_t ValidateBinaryUsingContextAndValidationState( 238 const spv_context_t& context, const uint32_t* words, const size_t num_words, 239 spv_diagnostic* pDiagnostic, ValidationState_t* vstate) { 240 auto binary = std::unique_ptr<spv_const_binary_t>( 241 new spv_const_binary_t{words, num_words}); 242 243 spv_endianness_t endian; 244 spv_position_t position = {}; 245 if (spvBinaryEndianness(binary.get(), &endian)) { 246 return libspirv::DiagnosticStream(position, context.consumer, 247 SPV_ERROR_INVALID_BINARY) 248 << "Invalid SPIR-V magic number."; 249 } 250 251 spv_header_t header; 252 if (spvBinaryHeaderGet(binary.get(), endian, &header)) { 253 return libspirv::DiagnosticStream(position, context.consumer, 254 SPV_ERROR_INVALID_BINARY) 255 << "Invalid SPIR-V header."; 256 } 257 258 // Look for OpExtension instructions and register extensions. 259 // Diagnostics if any will be produced in the next pass (ProcessInstruction). 260 spvBinaryParse(&context, vstate, words, num_words, 261 /* parsed_header = */ nullptr, ProcessExtensions, 262 /* diagnostic = */ nullptr); 263 264 // NOTE: Parse the module and perform inline validation checks. These 265 // checks do not require the the knowledge of the whole module. 266 if (auto error = spvBinaryParse(&context, vstate, words, num_words, 267 setHeader, ProcessInstruction, pDiagnostic)) 268 return error; 269 270 if (vstate->in_function_body()) 271 return vstate->diag(SPV_ERROR_INVALID_LAYOUT) 272 << "Missing OpFunctionEnd at end of module."; 273 274 // TODO(umar): Add validation checks which require the parsing of the entire 275 // module. Use the information from the ProcessInstruction pass to make the 276 // checks. 277 if (vstate->unresolved_forward_id_count() > 0) { 278 stringstream ss; 279 vector<uint32_t> ids = vstate->UnresolvedForwardIds(); 280 281 transform(begin(ids), end(ids), ostream_iterator<string>(ss, " "), 282 bind(&ValidationState_t::getIdName, std::ref(*vstate), _1)); 283 284 auto id_str = ss.str(); 285 return vstate->diag(SPV_ERROR_INVALID_ID) 286 << "The following forward referenced IDs have not been defined:\n" 287 << id_str.substr(0, id_str.size() - 1); 288 } 289 290 // CFG checks are performed after the binary has been parsed 291 // and the CFGPass has collected information about the control flow 292 if (auto error = PerformCfgChecks(*vstate)) return error; 293 if (auto error = UpdateIdUse(*vstate)) return error; 294 if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error; 295 if (auto error = ValidateDecorations(*vstate)) return error; 296 297 // Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the 298 // SPIRV spec: 299 // * There is at least one OpEntryPoint instruction, unless the Linkage 300 // capability is being used. 301 // * No function can be targeted by both an OpEntryPoint instruction and an 302 // OpFunctionCall instruction. 303 if (vstate->entry_points().empty() && 304 !vstate->HasCapability(SpvCapabilityLinkage)) { 305 return vstate->diag(SPV_ERROR_INVALID_BINARY) 306 << "No OpEntryPoint instruction was found. This is only allowed if " 307 "the Linkage capability is being used."; 308 } 309 for (const auto& entry_point : vstate->entry_points()) { 310 if (vstate->IsFunctionCallTarget(entry_point)) { 311 return vstate->diag(SPV_ERROR_INVALID_BINARY) 312 << "A function (" << entry_point 313 << ") may not be targeted by both an OpEntryPoint instruction and " 314 "an OpFunctionCall instruction."; 315 } 316 } 317 318 // NOTE: Copy each instruction for easier processing 319 std::vector<spv_instruction_t> instructions; 320 // Expect average instruction length to be a bit over 2 words. 321 instructions.reserve(binary->wordCount / 2); 322 uint64_t index = SPV_INDEX_INSTRUCTION; 323 while (index < binary->wordCount) { 324 uint16_t wordCount; 325 uint16_t opcode; 326 spvOpcodeSplit(spvFixWord(binary->code[index], endian), &wordCount, 327 &opcode); 328 spv_instruction_t inst; 329 spvInstructionCopy(&binary->code[index], static_cast<SpvOp>(opcode), 330 wordCount, endian, &inst); 331 instructions.emplace_back(std::move(inst)); 332 index += wordCount; 333 } 334 335 position.index = SPV_INDEX_INSTRUCTION; 336 return spvValidateIDs(instructions.data(), instructions.size(), 337 context.opcode_table, 338 context.operand_table, 339 context.ext_inst_table, *vstate, &position); 340 } 341 342 spv_result_t spvValidateBinary(const spv_const_context context, 343 const uint32_t* words, const size_t num_words, 344 spv_diagnostic* pDiagnostic) { 345 spv_context_t hijack_context = *context; 346 if (pDiagnostic) { 347 *pDiagnostic = nullptr; 348 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); 349 } 350 351 // This interface is used for default command line options. 352 spv_validator_options default_options = spvValidatorOptionsCreate(); 353 354 // Create the ValidationState using the context and default options. 355 ValidationState_t vstate(&hijack_context, default_options); 356 357 spv_result_t result = ValidateBinaryUsingContextAndValidationState( 358 hijack_context, words, num_words, pDiagnostic, &vstate); 359 360 spvValidatorOptionsDestroy(default_options); 361 return result; 362 } 363 364 spv_result_t spvValidateWithOptions(const spv_const_context context, 365 spv_const_validator_options options, 366 const spv_const_binary binary, 367 spv_diagnostic* pDiagnostic) { 368 spv_context_t hijack_context = *context; 369 if (pDiagnostic) { 370 *pDiagnostic = nullptr; 371 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); 372 } 373 374 // Create the ValidationState using the context. 375 ValidationState_t vstate(&hijack_context, options); 376 377 return ValidateBinaryUsingContextAndValidationState( 378 hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate); 379 } 380 381 namespace spvtools { 382 383 spv_result_t ValidateBinaryAndKeepValidationState( 384 const spv_const_context context, spv_const_validator_options options, 385 const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic, 386 std::unique_ptr<ValidationState_t>* vstate) { 387 spv_context_t hijack_context = *context; 388 if (pDiagnostic) { 389 *pDiagnostic = nullptr; 390 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic); 391 } 392 393 vstate->reset(new ValidationState_t(&hijack_context, options)); 394 395 return ValidateBinaryUsingContextAndValidationState( 396 hijack_context, words, num_words, pDiagnostic, vstate->get()); 397 } 398 399 spv_result_t ValidateInstructionAndUpdateValidationState( 400 ValidationState_t* vstate, const spv_parsed_instruction_t* inst) { 401 return ProcessInstruction(vstate, inst); 402 } 403 404 } // namespace spvtools 405