Home | History | Annotate | Download | only in val
      1 // Copyright (c) 2018 Google LLC.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/val/validate.h"
     16 
     17 #include "source/opcode.h"
     18 #include "source/val/instruction.h"
     19 #include "source/val/validation_state.h"
     20 
     21 namespace spvtools {
     22 namespace val {
     23 namespace {
     24 
     25 spv_result_t ValidateConstantBool(ValidationState_t& _,
     26                                   const Instruction* inst) {
     27   auto type = _.FindDef(inst->type_id());
     28   if (!type || type->opcode() != SpvOpTypeBool) {
     29     return _.diag(SPV_ERROR_INVALID_ID, inst)
     30            << "Op" << spvOpcodeString(inst->opcode()) << " Result Type <id> '"
     31            << _.getIdName(inst->type_id()) << "' is not a boolean type.";
     32   }
     33 
     34   return SPV_SUCCESS;
     35 }
     36 
     37 spv_result_t ValidateConstantComposite(ValidationState_t& _,
     38                                        const Instruction* inst) {
     39   std::string opcode_name = std::string("Op") + spvOpcodeString(inst->opcode());
     40 
     41   const auto result_type = _.FindDef(inst->type_id());
     42   if (!result_type || !spvOpcodeIsComposite(result_type->opcode())) {
     43     return _.diag(SPV_ERROR_INVALID_ID, inst)
     44            << opcode_name << " Result Type <id> '"
     45            << _.getIdName(inst->type_id()) << "' is not a composite type.";
     46   }
     47 
     48   const auto constituent_count = inst->words().size() - 3;
     49   switch (result_type->opcode()) {
     50     case SpvOpTypeVector: {
     51       const auto component_count = result_type->GetOperandAs<uint32_t>(2);
     52       if (component_count != constituent_count) {
     53         // TODO: Output ID's on diagnostic
     54         return _.diag(SPV_ERROR_INVALID_ID, inst)
     55                << opcode_name
     56                << " Constituent <id> count does not match "
     57                   "Result Type <id> '"
     58                << _.getIdName(result_type->id())
     59                << "'s vector component count.";
     60       }
     61       const auto component_type =
     62           _.FindDef(result_type->GetOperandAs<uint32_t>(1));
     63       if (!component_type) {
     64         return _.diag(SPV_ERROR_INVALID_ID, result_type)
     65                << "Component type is not defined.";
     66       }
     67       for (size_t constituent_index = 2;
     68            constituent_index < inst->operands().size(); constituent_index++) {
     69         const auto constituent_id =
     70             inst->GetOperandAs<uint32_t>(constituent_index);
     71         const auto constituent = _.FindDef(constituent_id);
     72         if (!constituent ||
     73             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
     74           return _.diag(SPV_ERROR_INVALID_ID, inst)
     75                  << opcode_name << " Constituent <id> '"
     76                  << _.getIdName(constituent_id)
     77                  << "' is not a constant or undef.";
     78         }
     79         const auto constituent_result_type = _.FindDef(constituent->type_id());
     80         if (!constituent_result_type ||
     81             component_type->opcode() != constituent_result_type->opcode()) {
     82           return _.diag(SPV_ERROR_INVALID_ID, inst)
     83                  << opcode_name << " Constituent <id> '"
     84                  << _.getIdName(constituent_id)
     85                  << "'s type does not match Result Type <id> '"
     86                  << _.getIdName(result_type->id()) << "'s vector element type.";
     87         }
     88       }
     89     } break;
     90     case SpvOpTypeMatrix: {
     91       const auto column_count = result_type->GetOperandAs<uint32_t>(2);
     92       if (column_count != constituent_count) {
     93         // TODO: Output ID's on diagnostic
     94         return _.diag(SPV_ERROR_INVALID_ID, inst)
     95                << opcode_name
     96                << " Constituent <id> count does not match "
     97                   "Result Type <id> '"
     98                << _.getIdName(result_type->id()) << "'s matrix column count.";
     99       }
    100 
    101       const auto column_type = _.FindDef(result_type->words()[2]);
    102       if (!column_type) {
    103         return _.diag(SPV_ERROR_INVALID_ID, result_type)
    104                << "Column type is not defined.";
    105       }
    106       const auto component_count = column_type->GetOperandAs<uint32_t>(2);
    107       const auto component_type =
    108           _.FindDef(column_type->GetOperandAs<uint32_t>(1));
    109       if (!component_type) {
    110         return _.diag(SPV_ERROR_INVALID_ID, column_type)
    111                << "Component type is not defined.";
    112       }
    113 
    114       for (size_t constituent_index = 2;
    115            constituent_index < inst->operands().size(); constituent_index++) {
    116         const auto constituent_id =
    117             inst->GetOperandAs<uint32_t>(constituent_index);
    118         const auto constituent = _.FindDef(constituent_id);
    119         if (!constituent ||
    120             !(SpvOpConstantComposite == constituent->opcode() ||
    121               SpvOpSpecConstantComposite == constituent->opcode() ||
    122               SpvOpUndef == constituent->opcode())) {
    123           // The message says "... or undef" because the spec does not say
    124           // undef is a constant.
    125           return _.diag(SPV_ERROR_INVALID_ID, inst)
    126                  << opcode_name << " Constituent <id> '"
    127                  << _.getIdName(constituent_id)
    128                  << "' is not a constant composite or undef.";
    129         }
    130         const auto vector = _.FindDef(constituent->type_id());
    131         if (!vector) {
    132           return _.diag(SPV_ERROR_INVALID_ID, constituent)
    133                  << "Result type is not defined.";
    134         }
    135         if (column_type->opcode() != vector->opcode()) {
    136           return _.diag(SPV_ERROR_INVALID_ID, inst)
    137                  << opcode_name << " Constituent <id> '"
    138                  << _.getIdName(constituent_id)
    139                  << "' type does not match Result Type <id> '"
    140                  << _.getIdName(result_type->id()) << "'s matrix column type.";
    141         }
    142         const auto vector_component_type =
    143             _.FindDef(vector->GetOperandAs<uint32_t>(1));
    144         if (component_type->id() != vector_component_type->id()) {
    145           return _.diag(SPV_ERROR_INVALID_ID, inst)
    146                  << opcode_name << " Constituent <id> '"
    147                  << _.getIdName(constituent_id)
    148                  << "' component type does not match Result Type <id> '"
    149                  << _.getIdName(result_type->id())
    150                  << "'s matrix column component type.";
    151         }
    152         if (component_count != vector->words()[3]) {
    153           return _.diag(SPV_ERROR_INVALID_ID, inst)
    154                  << opcode_name << " Constituent <id> '"
    155                  << _.getIdName(constituent_id)
    156                  << "' vector component count does not match Result Type <id> '"
    157                  << _.getIdName(result_type->id())
    158                  << "'s vector component count.";
    159         }
    160       }
    161     } break;
    162     case SpvOpTypeArray: {
    163       auto element_type = _.FindDef(result_type->GetOperandAs<uint32_t>(1));
    164       if (!element_type) {
    165         return _.diag(SPV_ERROR_INVALID_ID, result_type)
    166                << "Element type is not defined.";
    167       }
    168       const auto length = _.FindDef(result_type->GetOperandAs<uint32_t>(2));
    169       if (!length) {
    170         return _.diag(SPV_ERROR_INVALID_ID, result_type)
    171                << "Length is not defined.";
    172       }
    173       bool is_int32;
    174       bool is_const;
    175       uint32_t value;
    176       std::tie(is_int32, is_const, value) = _.EvalInt32IfConst(length->id());
    177       if (is_int32 && is_const && value != constituent_count) {
    178         return _.diag(SPV_ERROR_INVALID_ID, inst)
    179                << opcode_name
    180                << " Constituent count does not match "
    181                   "Result Type <id> '"
    182                << _.getIdName(result_type->id()) << "'s array length.";
    183       }
    184       for (size_t constituent_index = 2;
    185            constituent_index < inst->operands().size(); constituent_index++) {
    186         const auto constituent_id =
    187             inst->GetOperandAs<uint32_t>(constituent_index);
    188         const auto constituent = _.FindDef(constituent_id);
    189         if (!constituent ||
    190             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
    191           return _.diag(SPV_ERROR_INVALID_ID, inst)
    192                  << opcode_name << " Constituent <id> '"
    193                  << _.getIdName(constituent_id)
    194                  << "' is not a constant or undef.";
    195         }
    196         const auto constituent_type = _.FindDef(constituent->type_id());
    197         if (!constituent_type) {
    198           return _.diag(SPV_ERROR_INVALID_ID, constituent)
    199                  << "Result type is not defined.";
    200         }
    201         if (element_type->id() != constituent_type->id()) {
    202           return _.diag(SPV_ERROR_INVALID_ID, inst)
    203                  << opcode_name << " Constituent <id> '"
    204                  << _.getIdName(constituent_id)
    205                  << "'s type does not match Result Type <id> '"
    206                  << _.getIdName(result_type->id()) << "'s array element type.";
    207         }
    208       }
    209     } break;
    210     case SpvOpTypeStruct: {
    211       const auto member_count = result_type->words().size() - 2;
    212       if (member_count != constituent_count) {
    213         return _.diag(SPV_ERROR_INVALID_ID, inst)
    214                << opcode_name << " Constituent <id> '"
    215                << _.getIdName(inst->type_id())
    216                << "' count does not match Result Type <id> '"
    217                << _.getIdName(result_type->id()) << "'s struct member count.";
    218       }
    219       for (uint32_t constituent_index = 2, member_index = 1;
    220            constituent_index < inst->operands().size();
    221            constituent_index++, member_index++) {
    222         const auto constituent_id =
    223             inst->GetOperandAs<uint32_t>(constituent_index);
    224         const auto constituent = _.FindDef(constituent_id);
    225         if (!constituent ||
    226             !spvOpcodeIsConstantOrUndef(constituent->opcode())) {
    227           return _.diag(SPV_ERROR_INVALID_ID, inst)
    228                  << opcode_name << " Constituent <id> '"
    229                  << _.getIdName(constituent_id)
    230                  << "' is not a constant or undef.";
    231         }
    232         const auto constituent_type = _.FindDef(constituent->type_id());
    233         if (!constituent_type) {
    234           return _.diag(SPV_ERROR_INVALID_ID, constituent)
    235                  << "Result type is not defined.";
    236         }
    237 
    238         const auto member_type_id =
    239             result_type->GetOperandAs<uint32_t>(member_index);
    240         const auto member_type = _.FindDef(member_type_id);
    241         if (!member_type || member_type->id() != constituent_type->id()) {
    242           return _.diag(SPV_ERROR_INVALID_ID, inst)
    243                  << opcode_name << " Constituent <id> '"
    244                  << _.getIdName(constituent_id)
    245                  << "' type does not match the Result Type <id> '"
    246                  << _.getIdName(result_type->id()) << "'s member type.";
    247         }
    248       }
    249     } break;
    250     default:
    251       break;
    252   }
    253   return SPV_SUCCESS;
    254 }
    255 
    256 spv_result_t ValidateConstantSampler(ValidationState_t& _,
    257                                      const Instruction* inst) {
    258   const auto result_type = _.FindDef(inst->type_id());
    259   if (!result_type || result_type->opcode() != SpvOpTypeSampler) {
    260     return _.diag(SPV_ERROR_INVALID_ID, result_type)
    261            << "OpConstantSampler Result Type <id> '"
    262            << _.getIdName(inst->type_id()) << "' is not a sampler type.";
    263   }
    264 
    265   return SPV_SUCCESS;
    266 }
    267 
    268 // True if instruction defines a type that can have a null value, as defined by
    269 // the SPIR-V spec.  Tracks composite-type components through module to check
    270 // nullability transitively.
    271 bool IsTypeNullable(const std::vector<uint32_t>& instruction,
    272                     const ValidationState_t& _) {
    273   uint16_t opcode;
    274   uint16_t word_count;
    275   spvOpcodeSplit(instruction[0], &word_count, &opcode);
    276   switch (static_cast<SpvOp>(opcode)) {
    277     case SpvOpTypeBool:
    278     case SpvOpTypeInt:
    279     case SpvOpTypeFloat:
    280     case SpvOpTypePointer:
    281     case SpvOpTypeEvent:
    282     case SpvOpTypeDeviceEvent:
    283     case SpvOpTypeReserveId:
    284     case SpvOpTypeQueue:
    285       return true;
    286     case SpvOpTypeArray:
    287     case SpvOpTypeMatrix:
    288     case SpvOpTypeVector: {
    289       auto base_type = _.FindDef(instruction[2]);
    290       return base_type && IsTypeNullable(base_type->words(), _);
    291     }
    292     case SpvOpTypeStruct: {
    293       for (size_t elementIndex = 2; elementIndex < instruction.size();
    294            ++elementIndex) {
    295         auto element = _.FindDef(instruction[elementIndex]);
    296         if (!element || !IsTypeNullable(element->words(), _)) return false;
    297       }
    298       return true;
    299     }
    300     default:
    301       return false;
    302   }
    303 }
    304 
    305 spv_result_t ValidateConstantNull(ValidationState_t& _,
    306                                   const Instruction* inst) {
    307   const auto result_type = _.FindDef(inst->type_id());
    308   if (!result_type || !IsTypeNullable(result_type->words(), _)) {
    309     return _.diag(SPV_ERROR_INVALID_ID, inst)
    310            << "OpConstantNull Result Type <id> '"
    311            << _.getIdName(inst->type_id()) << "' cannot have a null value.";
    312   }
    313 
    314   return SPV_SUCCESS;
    315 }
    316 
    317 }  // namespace
    318 
    319 spv_result_t ConstantPass(ValidationState_t& _, const Instruction* inst) {
    320   switch (inst->opcode()) {
    321     case SpvOpConstantTrue:
    322     case SpvOpConstantFalse:
    323     case SpvOpSpecConstantTrue:
    324     case SpvOpSpecConstantFalse:
    325       if (auto error = ValidateConstantBool(_, inst)) return error;
    326       break;
    327     case SpvOpConstantComposite:
    328     case SpvOpSpecConstantComposite:
    329       if (auto error = ValidateConstantComposite(_, inst)) return error;
    330       break;
    331     case SpvOpConstantSampler:
    332       if (auto error = ValidateConstantSampler(_, inst)) return error;
    333       break;
    334     case SpvOpConstantNull:
    335       if (auto error = ValidateConstantNull(_, inst)) return error;
    336       break;
    337     default:
    338       break;
    339   }
    340 
    341   return SPV_SUCCESS;
    342 }
    343 
    344 }  // namespace val
    345 }  // namespace spvtools
    346