Home | History | Annotate | Download | only in source
      1 // Copyright (c) 2015-2016 The Khronos Group Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "validate.h"
     16 
     17 #include <cassert>
     18 #include <cstdio>
     19 
     20 #include <algorithm>
     21 #include <functional>
     22 #include <iterator>
     23 #include <memory>
     24 #include <sstream>
     25 #include <string>
     26 #include <vector>
     27 
     28 #include "binary.h"
     29 #include "diagnostic.h"
     30 #include "enum_string_mapping.h"
     31 #include "extensions.h"
     32 #include "instruction.h"
     33 #include "opcode.h"
     34 #include "operand.h"
     35 #include "spirv-tools/libspirv.h"
     36 #include "spirv_constant.h"
     37 #include "spirv_endian.h"
     38 #include "spirv_validator_options.h"
     39 #include "val/construct.h"
     40 #include "val/function.h"
     41 #include "val/validation_state.h"
     42 
     43 using std::function;
     44 using std::ostream_iterator;
     45 using std::placeholders::_1;
     46 using std::string;
     47 using std::stringstream;
     48 using std::transform;
     49 using std::vector;
     50 
     51 using libspirv::CfgPass;
     52 using libspirv::Extension;
     53 using libspirv::InstructionPass;
     54 using libspirv::ModuleLayoutPass;
     55 using libspirv::DataRulesPass;
     56 using libspirv::IdPass;
     57 using libspirv::ValidationState_t;
     58 
     59 spv_result_t spvValidateIDs(const spv_instruction_t* pInsts,
     60                             const uint64_t count,
     61                             const spv_opcode_table opcodeTable,
     62                             const spv_operand_table operandTable,
     63                             const spv_ext_inst_table extInstTable,
     64                             const ValidationState_t& state,
     65                             spv_position position) {
     66   position->index = SPV_INDEX_INSTRUCTION;
     67   if (auto error =
     68           spvValidateInstructionIDs(pInsts, count, opcodeTable, operandTable,
     69                                     extInstTable, state, position))
     70     return error;
     71   return SPV_SUCCESS;
     72 }
     73 
     74 namespace {
     75 
     76 // TODO(umar): Validate header
     77 // TODO(umar): The binary parser validates the magic word, and the length of the
     78 // header, but nothing else.
     79 spv_result_t setHeader(void* user_data, spv_endianness_t endian, uint32_t magic,
     80                        uint32_t version, uint32_t generator, uint32_t id_bound,
     81                        uint32_t reserved) {
     82   // Record the ID bound so that the validator can ensure no ID is out of bound.
     83   ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
     84   _.setIdBound(id_bound);
     85 
     86   (void)endian;
     87   (void)magic;
     88   (void)version;
     89   (void)generator;
     90   (void)id_bound;
     91   (void)reserved;
     92   return SPV_SUCCESS;
     93 }
     94 
     95 // Improves diagnostic messages by collecting names of IDs
     96 // NOTE: This function returns void and is not involved in validation
     97 void DebugInstructionPass(ValidationState_t& _,
     98                           const spv_parsed_instruction_t* inst) {
     99   switch (inst->opcode) {
    100     case SpvOpName: {
    101       const uint32_t target = *(inst->words + inst->operands[0].offset);
    102       const char* str =
    103           reinterpret_cast<const char*>(inst->words + inst->operands[1].offset);
    104       _.AssignNameToId(target, str);
    105     } break;
    106     case SpvOpMemberName: {
    107       const uint32_t target = *(inst->words + inst->operands[0].offset);
    108       const char* str =
    109           reinterpret_cast<const char*>(inst->words + inst->operands[2].offset);
    110       _.AssignNameToId(target, str);
    111     } break;
    112     case SpvOpSourceContinued:
    113     case SpvOpSource:
    114     case SpvOpSourceExtension:
    115     case SpvOpString:
    116     case SpvOpLine:
    117     case SpvOpNoLine:
    118 
    119     default:
    120       break;
    121   }
    122 }
    123 
    124 // Parses OpExtension instruction and registers extension.
    125 void RegisterExtension(ValidationState_t& _,
    126                        const spv_parsed_instruction_t* inst) {
    127   const std::string extension_str = libspirv::GetExtensionString(inst);
    128   Extension extension;
    129   if (!GetExtensionFromString(extension_str, &extension)) {
    130     // The error will be logged in the ProcessInstruction pass.
    131     return;
    132   }
    133 
    134   _.RegisterExtension(extension);
    135 }
    136 
    137 // Parses the beginning of the module searching for OpExtension instructions.
    138 // Registers extensions if recognized. Returns SPV_REQUESTED_TERMINATION
    139 // once an instruction which is not SpvOpCapability and SpvOpExtension is
    140 // encountered. According to the SPIR-V spec extensions are declared after
    141 // capabilities and before everything else.
    142 spv_result_t ProcessExtensions(
    143     void* user_data, const spv_parsed_instruction_t* inst) {
    144   const SpvOp opcode = static_cast<SpvOp>(inst->opcode);
    145   if (opcode == SpvOpCapability)
    146     return SPV_SUCCESS;
    147 
    148   if (opcode == SpvOpExtension) {
    149     ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
    150     RegisterExtension(_, inst);
    151     return SPV_SUCCESS;
    152   }
    153 
    154   // OpExtension block is finished, requesting termination.
    155   return SPV_REQUESTED_TERMINATION;
    156 }
    157 
    158 spv_result_t ProcessInstruction(void* user_data,
    159                                 const spv_parsed_instruction_t* inst) {
    160   ValidationState_t& _ = *(reinterpret_cast<ValidationState_t*>(user_data));
    161   _.increment_instruction_count();
    162   if (static_cast<SpvOp>(inst->opcode) == SpvOpEntryPoint) {
    163     const auto entry_point = inst->words[2];
    164     _.RegisterEntryPointId(entry_point);
    165     // Operand 3 and later are the <id> of interfaces for the entry point.
    166     for (int i = 3; i < inst->num_operands; ++i) {
    167       _.RegisterInterfaceForEntryPoint(entry_point,
    168                                        inst->words[inst->operands[i].offset]);
    169     }
    170   }
    171   if (static_cast<SpvOp>(inst->opcode) == SpvOpFunctionCall) {
    172     _.AddFunctionCallTarget(inst->words[3]);
    173   }
    174 
    175   DebugInstructionPass(_, inst);
    176   if (auto error = CapabilityPass(_, inst)) return error;
    177   if (auto error = DataRulesPass(_, inst)) return error;
    178   if (auto error = IdPass(_, inst)) return error;
    179   if (auto error = ModuleLayoutPass(_, inst)) return error;
    180   if (auto error = CfgPass(_, inst)) return error;
    181   if (auto error = InstructionPass(_, inst)) return error;
    182   if (auto error = TypeUniquePass(_, inst)) return error;
    183 
    184   return SPV_SUCCESS;
    185 }
    186 
    187 void printDot(const ValidationState_t& _, const libspirv::BasicBlock& other) {
    188   string block_string;
    189   if (other.successors()->empty()) {
    190     block_string += "end ";
    191   } else {
    192     for (auto block : *other.successors()) {
    193       block_string += _.getIdOrName(block->id()) + " ";
    194     }
    195   }
    196   printf("%10s -> {%s\b}\n", _.getIdOrName(other.id()).c_str(),
    197          block_string.c_str());
    198 }
    199 
    200 void PrintBlocks(ValidationState_t& _, libspirv::Function func) {
    201   assert(func.first_block());
    202 
    203   printf("%10s -> %s\n", _.getIdOrName(func.id()).c_str(),
    204          _.getIdOrName(func.first_block()->id()).c_str());
    205   for (const auto& block : func.ordered_blocks()) {
    206     printDot(_, *block);
    207   }
    208 }
    209 
    210 #ifdef __clang__
    211 #define UNUSED(func) [[gnu::unused]] func
    212 #elif defined(__GNUC__)
    213 #define UNUSED(func)            \
    214   func __attribute__((unused)); \
    215   func
    216 #elif defined(_MSC_VER)
    217 #define UNUSED(func) func
    218 #endif
    219 
    220 UNUSED(void PrintDotGraph(ValidationState_t& _, libspirv::Function func)) {
    221   if (func.first_block()) {
    222     string func_name(_.getIdOrName(func.id()));
    223     printf("digraph %s {\n", func_name.c_str());
    224     PrintBlocks(_, func);
    225     printf("}\n");
    226   }
    227 }
    228 }  // anonymous namespace
    229 
    230 spv_result_t spvValidate(const spv_const_context context,
    231                          const spv_const_binary binary,
    232                          spv_diagnostic* pDiagnostic) {
    233   return spvValidateBinary(context, binary->code, binary->wordCount,
    234                            pDiagnostic);
    235 }
    236 
    237 spv_result_t ValidateBinaryUsingContextAndValidationState(
    238     const spv_context_t& context, const uint32_t* words, const size_t num_words,
    239     spv_diagnostic* pDiagnostic, ValidationState_t* vstate) {
    240   auto binary = std::unique_ptr<spv_const_binary_t>(
    241     new spv_const_binary_t{words, num_words});
    242 
    243   spv_endianness_t endian;
    244   spv_position_t position = {};
    245   if (spvBinaryEndianness(binary.get(), &endian)) {
    246     return libspirv::DiagnosticStream(position, context.consumer,
    247                                       SPV_ERROR_INVALID_BINARY)
    248            << "Invalid SPIR-V magic number.";
    249   }
    250 
    251   spv_header_t header;
    252   if (spvBinaryHeaderGet(binary.get(), endian, &header)) {
    253     return libspirv::DiagnosticStream(position, context.consumer,
    254                                       SPV_ERROR_INVALID_BINARY)
    255            << "Invalid SPIR-V header.";
    256   }
    257 
    258   // Look for OpExtension instructions and register extensions.
    259   // Diagnostics if any will be produced in the next pass (ProcessInstruction).
    260   spvBinaryParse(&context, vstate, words, num_words,
    261                  /* parsed_header = */ nullptr, ProcessExtensions,
    262                  /* diagnostic = */ nullptr);
    263 
    264   // NOTE: Parse the module and perform inline validation checks. These
    265   // checks do not require the the knowledge of the whole module.
    266   if (auto error = spvBinaryParse(&context, vstate, words, num_words,
    267                                   setHeader, ProcessInstruction, pDiagnostic))
    268     return error;
    269 
    270   if (vstate->in_function_body())
    271     return vstate->diag(SPV_ERROR_INVALID_LAYOUT)
    272            << "Missing OpFunctionEnd at end of module.";
    273 
    274   // TODO(umar): Add validation checks which require the parsing of the entire
    275   // module. Use the information from the ProcessInstruction pass to make the
    276   // checks.
    277   if (vstate->unresolved_forward_id_count() > 0) {
    278     stringstream ss;
    279     vector<uint32_t> ids = vstate->UnresolvedForwardIds();
    280 
    281     transform(begin(ids), end(ids), ostream_iterator<string>(ss, " "),
    282               bind(&ValidationState_t::getIdName, std::ref(*vstate), _1));
    283 
    284     auto id_str = ss.str();
    285     return vstate->diag(SPV_ERROR_INVALID_ID)
    286            << "The following forward referenced IDs have not been defined:\n"
    287            << id_str.substr(0, id_str.size() - 1);
    288   }
    289 
    290   // CFG checks are performed after the binary has been parsed
    291   // and the CFGPass has collected information about the control flow
    292   if (auto error = PerformCfgChecks(*vstate)) return error;
    293   if (auto error = UpdateIdUse(*vstate)) return error;
    294   if (auto error = CheckIdDefinitionDominateUse(*vstate)) return error;
    295   if (auto error = ValidateDecorations(*vstate))  return error;
    296 
    297   // Entry point validation. Based on 2.16.1 (Universal Validation Rules) of the
    298   // SPIRV spec:
    299   // * There is at least one OpEntryPoint instruction, unless the Linkage
    300   // capability is being used.
    301   // * No function can be targeted by both an OpEntryPoint instruction and an
    302   // OpFunctionCall instruction.
    303   if (vstate->entry_points().empty() &&
    304       !vstate->HasCapability(SpvCapabilityLinkage)) {
    305     return vstate->diag(SPV_ERROR_INVALID_BINARY)
    306            << "No OpEntryPoint instruction was found. This is only allowed if "
    307               "the Linkage capability is being used.";
    308   }
    309   for (const auto& entry_point : vstate->entry_points()) {
    310     if (vstate->IsFunctionCallTarget(entry_point)) {
    311       return vstate->diag(SPV_ERROR_INVALID_BINARY)
    312              << "A function (" << entry_point
    313              << ") may not be targeted by both an OpEntryPoint instruction and "
    314                 "an OpFunctionCall instruction.";
    315     }
    316   }
    317 
    318   // NOTE: Copy each instruction for easier processing
    319   std::vector<spv_instruction_t> instructions;
    320   // Expect average instruction length to be a bit over 2 words.
    321   instructions.reserve(binary->wordCount / 2);
    322   uint64_t index = SPV_INDEX_INSTRUCTION;
    323   while (index < binary->wordCount) {
    324     uint16_t wordCount;
    325     uint16_t opcode;
    326     spvOpcodeSplit(spvFixWord(binary->code[index], endian), &wordCount,
    327                    &opcode);
    328     spv_instruction_t inst;
    329     spvInstructionCopy(&binary->code[index], static_cast<SpvOp>(opcode),
    330                        wordCount, endian, &inst);
    331     instructions.emplace_back(std::move(inst));
    332     index += wordCount;
    333   }
    334 
    335   position.index = SPV_INDEX_INSTRUCTION;
    336   return spvValidateIDs(instructions.data(), instructions.size(),
    337                         context.opcode_table,
    338                         context.operand_table,
    339                         context.ext_inst_table, *vstate, &position);
    340 }
    341 
    342 spv_result_t spvValidateBinary(const spv_const_context context,
    343                                const uint32_t* words, const size_t num_words,
    344                                spv_diagnostic* pDiagnostic) {
    345   spv_context_t hijack_context = *context;
    346   if (pDiagnostic) {
    347     *pDiagnostic = nullptr;
    348     libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
    349   }
    350 
    351   // This interface is used for default command line options.
    352   spv_validator_options default_options = spvValidatorOptionsCreate();
    353 
    354   // Create the ValidationState using the context and default options.
    355   ValidationState_t vstate(&hijack_context, default_options);
    356 
    357   spv_result_t result = ValidateBinaryUsingContextAndValidationState(
    358       hijack_context, words, num_words, pDiagnostic, &vstate);
    359 
    360   spvValidatorOptionsDestroy(default_options);
    361   return result;
    362 }
    363 
    364 spv_result_t spvValidateWithOptions(const spv_const_context context,
    365                                     spv_const_validator_options options,
    366                                     const spv_const_binary binary,
    367                                     spv_diagnostic* pDiagnostic) {
    368   spv_context_t hijack_context = *context;
    369   if (pDiagnostic) {
    370     *pDiagnostic = nullptr;
    371     libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
    372   }
    373 
    374   // Create the ValidationState using the context.
    375   ValidationState_t vstate(&hijack_context, options);
    376 
    377   return ValidateBinaryUsingContextAndValidationState(
    378       hijack_context, binary->code, binary->wordCount, pDiagnostic, &vstate);
    379 }
    380 
    381 namespace spvtools {
    382 
    383 spv_result_t ValidateBinaryAndKeepValidationState(
    384     const spv_const_context context, spv_const_validator_options options,
    385     const uint32_t* words, const size_t num_words, spv_diagnostic* pDiagnostic,
    386     std::unique_ptr<ValidationState_t>* vstate) {
    387   spv_context_t hijack_context = *context;
    388   if (pDiagnostic) {
    389     *pDiagnostic = nullptr;
    390     libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, pDiagnostic);
    391   }
    392 
    393   vstate->reset(new ValidationState_t(&hijack_context, options));
    394 
    395   return ValidateBinaryUsingContextAndValidationState(
    396       hijack_context, words, num_words, pDiagnostic, vstate->get());
    397 }
    398 
    399 spv_result_t ValidateInstructionAndUpdateValidationState(
    400     ValidationState_t* vstate, const spv_parsed_instruction_t* inst) {
    401   return ProcessInstruction(vstate, inst);
    402 }
    403 
    404 }  // namespace spvtools
    405