Home | History | Annotate | Download | only in val
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 // Validates correctness of composite SPIR-V instructions.
     16 
     17 #include "source/val/validate.h"
     18 
     19 #include "source/diagnostic.h"
     20 #include "source/opcode.h"
     21 #include "source/spirv_target_env.h"
     22 #include "source/val/instruction.h"
     23 #include "source/val/validation_state.h"
     24 
     25 namespace spvtools {
     26 namespace val {
     27 namespace {
     28 
     29 // Returns the type of the value accessed by OpCompositeExtract or
     30 // OpCompositeInsert instruction. The function traverses the hierarchy of
     31 // nested data structures (structs, arrays, vectors, matrices) as directed by
     32 // the sequence of indices in the instruction. May return error if traversal
     33 // fails (encountered non-composite, out of bounds, nesting too deep).
     34 // Returns the type of Composite operand if the instruction has no indices.
     35 spv_result_t GetExtractInsertValueType(ValidationState_t& _,
     36                                        const Instruction* inst,
     37                                        uint32_t* member_type) {
     38   const SpvOp opcode = inst->opcode();
     39   assert(opcode == SpvOpCompositeExtract || opcode == SpvOpCompositeInsert);
     40   uint32_t word_index = opcode == SpvOpCompositeExtract ? 4 : 5;
     41   const uint32_t num_words = static_cast<uint32_t>(inst->words().size());
     42   const uint32_t composite_id_index = word_index - 1;
     43 
     44   const uint32_t num_indices = num_words - word_index;
     45   const uint32_t kCompositeExtractInsertMaxNumIndices = 255;
     46   if (num_indices > kCompositeExtractInsertMaxNumIndices) {
     47     return _.diag(SPV_ERROR_INVALID_DATA, inst)
     48            << "The number of indexes in Op" << spvOpcodeString(opcode)
     49            << " may not exceed " << kCompositeExtractInsertMaxNumIndices
     50            << ". Found " << num_indices << " indexes.";
     51   }
     52 
     53   *member_type = _.GetTypeId(inst->word(composite_id_index));
     54   if (*member_type == 0) {
     55     return _.diag(SPV_ERROR_INVALID_DATA, inst)
     56            << "Expected Composite to be an object of composite type";
     57   }
     58 
     59   for (; word_index < num_words; ++word_index) {
     60     const uint32_t component_index = inst->word(word_index);
     61     const Instruction* const type_inst = _.FindDef(*member_type);
     62     assert(type_inst);
     63     switch (type_inst->opcode()) {
     64       case SpvOpTypeVector: {
     65         *member_type = type_inst->word(2);
     66         const uint32_t vector_size = type_inst->word(3);
     67         if (component_index >= vector_size) {
     68           return _.diag(SPV_ERROR_INVALID_DATA, inst)
     69                  << "Vector access is out of bounds, vector size is "
     70                  << vector_size << ", but access index is " << component_index;
     71         }
     72         break;
     73       }
     74       case SpvOpTypeMatrix: {
     75         *member_type = type_inst->word(2);
     76         const uint32_t num_cols = type_inst->word(3);
     77         if (component_index >= num_cols) {
     78           return _.diag(SPV_ERROR_INVALID_DATA, inst)
     79                  << "Matrix access is out of bounds, matrix has " << num_cols
     80                  << " columns, but access index is " << component_index;
     81         }
     82         break;
     83       }
     84       case SpvOpTypeArray: {
     85         uint64_t array_size = 0;
     86         auto size = _.FindDef(type_inst->word(3));
     87         *member_type = type_inst->word(2);
     88         if (spvOpcodeIsSpecConstant(size->opcode())) {
     89           // Cannot verify against the size of this array.
     90           break;
     91         }
     92 
     93         if (!_.GetConstantValUint64(type_inst->word(3), &array_size)) {
     94           assert(0 && "Array type definition is corrupt");
     95         }
     96         if (component_index >= array_size) {
     97           return _.diag(SPV_ERROR_INVALID_DATA, inst)
     98                  << "Array access is out of bounds, array size is "
     99                  << array_size << ", but access index is " << component_index;
    100         }
    101         break;
    102       }
    103       case SpvOpTypeRuntimeArray: {
    104         *member_type = type_inst->word(2);
    105         // Array size is unknown.
    106         break;
    107       }
    108       case SpvOpTypeStruct: {
    109         const size_t num_struct_members = type_inst->words().size() - 2;
    110         if (component_index >= num_struct_members) {
    111           return _.diag(SPV_ERROR_INVALID_DATA, inst)
    112                  << "Index is out of bounds, can not find index "
    113                  << component_index << " in the structure <id> '"
    114                  << type_inst->id() << "'. This structure has "
    115                  << num_struct_members << " members. Largest valid index is "
    116                  << num_struct_members - 1 << ".";
    117         }
    118         *member_type = type_inst->word(component_index + 2);
    119         break;
    120       }
    121       default:
    122         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    123                << "Reached non-composite type while indexes still remain to "
    124                   "be traversed.";
    125     }
    126   }
    127 
    128   return SPV_SUCCESS;
    129 }
    130 
    131 spv_result_t ValidateVectorExtractDynamic(ValidationState_t& _,
    132                                           const Instruction* inst) {
    133   const uint32_t result_type = inst->type_id();
    134   const SpvOp result_opcode = _.GetIdOpcode(result_type);
    135   if (!spvOpcodeIsScalarType(result_opcode)) {
    136     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    137            << "Expected Result Type to be a scalar type";
    138   }
    139 
    140   const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
    141   const SpvOp vector_opcode = _.GetIdOpcode(vector_type);
    142   if (vector_opcode != SpvOpTypeVector) {
    143     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    144            << "Expected Vector type to be OpTypeVector";
    145   }
    146 
    147   if (_.GetComponentType(vector_type) != result_type) {
    148     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    149            << "Expected Vector component type to be equal to Result Type";
    150   }
    151 
    152   const auto index = _.FindDef(inst->GetOperandAs<uint32_t>(3));
    153   if (!index || index->type_id() == 0 || !_.IsIntScalarType(index->type_id())) {
    154     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    155            << "Expected Index to be int scalar";
    156   }
    157   return SPV_SUCCESS;
    158 }
    159 
    160 spv_result_t ValidateVectorInsertDyanmic(ValidationState_t& _,
    161                                          const Instruction* inst) {
    162   const uint32_t result_type = inst->type_id();
    163   const SpvOp result_opcode = _.GetIdOpcode(result_type);
    164   if (result_opcode != SpvOpTypeVector) {
    165     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    166            << "Expected Result Type to be OpTypeVector";
    167   }
    168 
    169   const uint32_t vector_type = _.GetOperandTypeId(inst, 2);
    170   if (vector_type != result_type) {
    171     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    172            << "Expected Vector type to be equal to Result Type";
    173   }
    174 
    175   const uint32_t component_type = _.GetOperandTypeId(inst, 3);
    176   if (_.GetComponentType(result_type) != component_type) {
    177     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    178            << "Expected Component type to be equal to Result Type "
    179            << "component type";
    180   }
    181 
    182   const uint32_t index_type = _.GetOperandTypeId(inst, 4);
    183   if (!_.IsIntScalarType(index_type)) {
    184     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    185            << "Expected Index to be int scalar";
    186   }
    187   return SPV_SUCCESS;
    188 }
    189 
    190 spv_result_t ValidateCompositeConstruct(ValidationState_t& _,
    191                                         const Instruction* inst) {
    192   const uint32_t num_operands = static_cast<uint32_t>(inst->operands().size());
    193   const uint32_t result_type = inst->type_id();
    194   const SpvOp result_opcode = _.GetIdOpcode(result_type);
    195   switch (result_opcode) {
    196     case SpvOpTypeVector: {
    197       const uint32_t num_result_components = _.GetDimension(result_type);
    198       const uint32_t result_component_type = _.GetComponentType(result_type);
    199       uint32_t given_component_count = 0;
    200 
    201       if (num_operands <= 3) {
    202         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    203                << "Expected number of constituents to be at least 2";
    204       }
    205 
    206       for (uint32_t operand_index = 2; operand_index < num_operands;
    207            ++operand_index) {
    208         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
    209         if (operand_type == result_component_type) {
    210           ++given_component_count;
    211         } else {
    212           if (_.GetIdOpcode(operand_type) != SpvOpTypeVector ||
    213               _.GetComponentType(operand_type) != result_component_type) {
    214             return _.diag(SPV_ERROR_INVALID_DATA, inst)
    215                    << "Expected Constituents to be scalars or vectors of"
    216                    << " the same type as Result Type components";
    217           }
    218 
    219           given_component_count += _.GetDimension(operand_type);
    220         }
    221       }
    222 
    223       if (num_result_components != given_component_count) {
    224         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    225                << "Expected total number of given components to be equal "
    226                << "to the size of Result Type vector";
    227       }
    228 
    229       break;
    230     }
    231     case SpvOpTypeMatrix: {
    232       uint32_t result_num_rows = 0;
    233       uint32_t result_num_cols = 0;
    234       uint32_t result_col_type = 0;
    235       uint32_t result_component_type = 0;
    236       if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
    237                                &result_col_type, &result_component_type)) {
    238         assert(0);
    239       }
    240 
    241       if (result_num_cols + 2 != num_operands) {
    242         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    243                << "Expected total number of Constituents to be equal "
    244                << "to the number of columns of Result Type matrix";
    245       }
    246 
    247       for (uint32_t operand_index = 2; operand_index < num_operands;
    248            ++operand_index) {
    249         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
    250         if (operand_type != result_col_type) {
    251           return _.diag(SPV_ERROR_INVALID_DATA, inst)
    252                  << "Expected Constituent type to be equal to the column "
    253                  << "type Result Type matrix";
    254         }
    255       }
    256 
    257       break;
    258     }
    259     case SpvOpTypeArray: {
    260       const Instruction* const array_inst = _.FindDef(result_type);
    261       assert(array_inst);
    262       assert(array_inst->opcode() == SpvOpTypeArray);
    263 
    264       auto size = _.FindDef(array_inst->word(3));
    265       if (spvOpcodeIsSpecConstant(size->opcode())) {
    266         // Cannot verify against the size of this array.
    267         break;
    268       }
    269 
    270       uint64_t array_size = 0;
    271       if (!_.GetConstantValUint64(array_inst->word(3), &array_size)) {
    272         assert(0 && "Array type definition is corrupt");
    273       }
    274 
    275       if (array_size + 2 != num_operands) {
    276         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    277                << "Expected total number of Constituents to be equal "
    278                << "to the number of elements of Result Type array";
    279       }
    280 
    281       const uint32_t result_component_type = array_inst->word(2);
    282       for (uint32_t operand_index = 2; operand_index < num_operands;
    283            ++operand_index) {
    284         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
    285         if (operand_type != result_component_type) {
    286           return _.diag(SPV_ERROR_INVALID_DATA, inst)
    287                  << "Expected Constituent type to be equal to the column "
    288                  << "type Result Type array";
    289         }
    290       }
    291 
    292       break;
    293     }
    294     case SpvOpTypeStruct: {
    295       const Instruction* const struct_inst = _.FindDef(result_type);
    296       assert(struct_inst);
    297       assert(struct_inst->opcode() == SpvOpTypeStruct);
    298 
    299       if (struct_inst->operands().size() + 1 != num_operands) {
    300         return _.diag(SPV_ERROR_INVALID_DATA, inst)
    301                << "Expected total number of Constituents to be equal "
    302                << "to the number of members of Result Type struct";
    303       }
    304 
    305       for (uint32_t operand_index = 2; operand_index < num_operands;
    306            ++operand_index) {
    307         const uint32_t operand_type = _.GetOperandTypeId(inst, operand_index);
    308         const uint32_t member_type = struct_inst->word(operand_index);
    309         if (operand_type != member_type) {
    310           return _.diag(SPV_ERROR_INVALID_DATA, inst)
    311                  << "Expected Constituent type to be equal to the "
    312                  << "corresponding member type of Result Type struct";
    313         }
    314       }
    315 
    316       break;
    317     }
    318     default: {
    319       return _.diag(SPV_ERROR_INVALID_DATA, inst)
    320              << "Expected Result Type to be a composite type";
    321     }
    322   }
    323   return SPV_SUCCESS;
    324 }
    325 
    326 spv_result_t ValidateCompositeExtract(ValidationState_t& _,
    327                                       const Instruction* inst) {
    328   uint32_t member_type = 0;
    329   if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
    330     return error;
    331   }
    332 
    333   const uint32_t result_type = inst->type_id();
    334   if (result_type != member_type) {
    335     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    336            << "Result type (Op" << spvOpcodeString(_.GetIdOpcode(result_type))
    337            << ") does not match the type that results from indexing into "
    338               "the composite (Op"
    339            << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
    340   }
    341   return SPV_SUCCESS;
    342 }
    343 
    344 spv_result_t ValidateCompositeInsert(ValidationState_t& _,
    345                                      const Instruction* inst) {
    346   const SpvOp opcode = inst->opcode();
    347   const uint32_t object_type = _.GetOperandTypeId(inst, 2);
    348   const uint32_t composite_type = _.GetOperandTypeId(inst, 3);
    349   const uint32_t result_type = inst->type_id();
    350   if (result_type != composite_type) {
    351     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    352            << "The Result Type must be the same as Composite type in Op"
    353            << spvOpcodeString(opcode) << " yielding Result Id " << result_type
    354            << ".";
    355   }
    356 
    357   uint32_t member_type = 0;
    358   if (spv_result_t error = GetExtractInsertValueType(_, inst, &member_type)) {
    359     return error;
    360   }
    361 
    362   if (object_type != member_type) {
    363     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    364            << "The Object type (Op"
    365            << spvOpcodeString(_.GetIdOpcode(object_type))
    366            << ") does not match the type that results from indexing into the "
    367               "Composite (Op"
    368            << spvOpcodeString(_.GetIdOpcode(member_type)) << ").";
    369   }
    370   return SPV_SUCCESS;
    371 }
    372 
    373 spv_result_t ValidateCopyObject(ValidationState_t& _, const Instruction* inst) {
    374   const uint32_t result_type = inst->type_id();
    375   const uint32_t operand_type = _.GetOperandTypeId(inst, 2);
    376   if (operand_type != result_type) {
    377     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    378            << "Expected Result Type and Operand type to be the same";
    379   }
    380   return SPV_SUCCESS;
    381 }
    382 
    383 spv_result_t ValidateTranspose(ValidationState_t& _, const Instruction* inst) {
    384   uint32_t result_num_rows = 0;
    385   uint32_t result_num_cols = 0;
    386   uint32_t result_col_type = 0;
    387   uint32_t result_component_type = 0;
    388   const uint32_t result_type = inst->type_id();
    389   if (!_.GetMatrixTypeInfo(result_type, &result_num_rows, &result_num_cols,
    390                            &result_col_type, &result_component_type)) {
    391     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    392            << "Expected Result Type to be a matrix type";
    393   }
    394 
    395   const uint32_t matrix_type = _.GetOperandTypeId(inst, 2);
    396   uint32_t matrix_num_rows = 0;
    397   uint32_t matrix_num_cols = 0;
    398   uint32_t matrix_col_type = 0;
    399   uint32_t matrix_component_type = 0;
    400   if (!_.GetMatrixTypeInfo(matrix_type, &matrix_num_rows, &matrix_num_cols,
    401                            &matrix_col_type, &matrix_component_type)) {
    402     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    403            << "Expected Matrix to be of type OpTypeMatrix";
    404   }
    405 
    406   if (result_component_type != matrix_component_type) {
    407     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    408            << "Expected component types of Matrix and Result Type to be "
    409            << "identical";
    410   }
    411 
    412   if (result_num_rows != matrix_num_cols ||
    413       result_num_cols != matrix_num_rows) {
    414     return _.diag(SPV_ERROR_INVALID_DATA, inst)
    415            << "Expected number of columns and the column size of Matrix "
    416            << "to be the reverse of those of Result Type";
    417   }
    418   return SPV_SUCCESS;
    419 }
    420 
    421 spv_result_t ValidateVectorShuffle(ValidationState_t& _,
    422                                    const Instruction* inst) {
    423   auto resultType = _.FindDef(inst->type_id());
    424   if (!resultType || resultType->opcode() != SpvOpTypeVector) {
    425     return _.diag(SPV_ERROR_INVALID_ID, inst)
    426            << "The Result Type of OpVectorShuffle must be"
    427            << " OpTypeVector. Found Op"
    428            << spvOpcodeString(static_cast<SpvOp>(resultType->opcode())) << ".";
    429   }
    430 
    431   // The number of components in Result Type must be the same as the number of
    432   // Component operands.
    433   auto componentCount = inst->operands().size() - 4;
    434   auto resultVectorDimension = resultType->GetOperandAs<uint32_t>(2);
    435   if (componentCount != resultVectorDimension) {
    436     return _.diag(SPV_ERROR_INVALID_ID, inst)
    437            << "OpVectorShuffle component literals count does not match "
    438               "Result Type <id> '"
    439            << _.getIdName(resultType->id()) << "'s vector component count.";
    440   }
    441 
    442   // Vector 1 and Vector 2 must both have vector types, with the same Component
    443   // Type as Result Type.
    444   auto vector1Object = _.FindDef(inst->GetOperandAs<uint32_t>(2));
    445   auto vector1Type = _.FindDef(vector1Object->type_id());
    446   auto vector2Object = _.FindDef(inst->GetOperandAs<uint32_t>(3));
    447   auto vector2Type = _.FindDef(vector2Object->type_id());
    448   if (!vector1Type || vector1Type->opcode() != SpvOpTypeVector) {
    449     return _.diag(SPV_ERROR_INVALID_ID, inst)
    450            << "The type of Vector 1 must be OpTypeVector.";
    451   }
    452   if (!vector2Type || vector2Type->opcode() != SpvOpTypeVector) {
    453     return _.diag(SPV_ERROR_INVALID_ID, inst)
    454            << "The type of Vector 2 must be OpTypeVector.";
    455   }
    456 
    457   auto resultComponentType = resultType->GetOperandAs<uint32_t>(1);
    458   if (vector1Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
    459     return _.diag(SPV_ERROR_INVALID_ID, inst)
    460            << "The Component Type of Vector 1 must be the same as ResultType.";
    461   }
    462   if (vector2Type->GetOperandAs<uint32_t>(1) != resultComponentType) {
    463     return _.diag(SPV_ERROR_INVALID_ID, inst)
    464            << "The Component Type of Vector 2 must be the same as ResultType.";
    465   }
    466 
    467   // All Component literals must either be FFFFFFFF or in [0, N - 1].
    468   // For WebGPU specifically, Component literals cannot be FFFFFFFF.
    469   auto vector1ComponentCount = vector1Type->GetOperandAs<uint32_t>(2);
    470   auto vector2ComponentCount = vector2Type->GetOperandAs<uint32_t>(2);
    471   auto N = vector1ComponentCount + vector2ComponentCount;
    472   auto firstLiteralIndex = 4;
    473   const auto is_webgpu_env = spvIsWebGPUEnv(_.context()->target_env);
    474   for (size_t i = firstLiteralIndex; i < inst->operands().size(); ++i) {
    475     auto literal = inst->GetOperandAs<uint32_t>(i);
    476     if (literal != 0xFFFFFFFF && literal >= N) {
    477       return _.diag(SPV_ERROR_INVALID_ID, inst)
    478              << "Component index " << literal << " is out of bounds for "
    479              << "combined (Vector1 + Vector2) size of " << N << ".";
    480     }
    481 
    482     if (is_webgpu_env && literal == 0xFFFFFFFF) {
    483       return _.diag(SPV_ERROR_INVALID_ID, inst)
    484              << "Component literal at operand " << i - firstLiteralIndex
    485              << " cannot be 0xFFFFFFFF in WebGPU execution environment.";
    486     }
    487   }
    488 
    489   return SPV_SUCCESS;
    490 }
    491 
    492 }  // anonymous namespace
    493 
    494 // Validates correctness of composite instructions.
    495 spv_result_t CompositesPass(ValidationState_t& _, const Instruction* inst) {
    496   switch (inst->opcode()) {
    497     case SpvOpVectorExtractDynamic:
    498       return ValidateVectorExtractDynamic(_, inst);
    499     case SpvOpVectorInsertDynamic:
    500       return ValidateVectorInsertDyanmic(_, inst);
    501     case SpvOpVectorShuffle:
    502       return ValidateVectorShuffle(_, inst);
    503     case SpvOpCompositeConstruct:
    504       return ValidateCompositeConstruct(_, inst);
    505     case SpvOpCompositeExtract:
    506       return ValidateCompositeExtract(_, inst);
    507     case SpvOpCompositeInsert:
    508       return ValidateCompositeInsert(_, inst);
    509     case SpvOpCopyObject:
    510       return ValidateCopyObject(_, inst);
    511     case SpvOpTranspose:
    512       return ValidateTranspose(_, inst);
    513     default:
    514       break;
    515   }
    516 
    517   return SPV_SUCCESS;
    518 }
    519 
    520 }  // namespace val
    521 }  // namespace spvtools
    522