Home | History | Annotate | Download | only in comp
      1 // Copyright (c) 2018 Google LLC
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/comp/markv_decoder.h"
     16 
     17 #include <cstring>
     18 #include <iterator>
     19 #include <numeric>
     20 
     21 #include "source/ext_inst.h"
     22 #include "source/opcode.h"
     23 #include "spirv-tools/libspirv.hpp"
     24 
     25 namespace spvtools {
     26 namespace comp {
     27 
     28 spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) {
     29   auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_);
     30 
     31   if (codec) {
     32     uint64_t decoded_value = 0;
     33     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
     34       return Diag(SPV_ERROR_INVALID_BINARY)
     35              << "Failed to decode non-id word with Huffman";
     36 
     37     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
     38       // The word decoded successfully.
     39       *word = uint32_t(decoded_value);
     40       assert(*word == decoded_value);
     41       return SPV_SUCCESS;
     42     }
     43 
     44     // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
     45   }
     46 
     47   const size_t chunk_length =
     48       model_->GetOperandVariableWidthChunkLength(operand_.type);
     49   if (chunk_length) {
     50     if (!reader_.ReadVariableWidthU32(word, chunk_length))
     51       return Diag(SPV_ERROR_INVALID_BINARY)
     52              << "Failed to decode non-id word with varint";
     53   } else {
     54     if (!reader_.ReadUnencoded(word))
     55       return Diag(SPV_ERROR_INVALID_BINARY)
     56              << "Failed to read unencoded non-id word";
     57   }
     58   return SPV_SUCCESS;
     59 }
     60 
     61 spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands(
     62     uint32_t* opcode, uint32_t* num_operands) {
     63   // First try to use the Markov chain codec.
     64   auto* codec =
     65       model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode());
     66   if (codec) {
     67     uint64_t decoded_value = 0;
     68     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
     69       return Diag(SPV_ERROR_INTERNAL)
     70              << "Failed to decode opcode_and_num_operands, previous opcode is "
     71              << spvOpcodeString(GetPrevOpcode());
     72 
     73     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
     74       // The word was successfully decoded.
     75       *opcode = uint32_t(decoded_value & 0xFFFF);
     76       *num_operands = uint32_t(decoded_value >> 16);
     77       return SPV_SUCCESS;
     78     }
     79 
     80     // Received kMarkvNoneOfTheAbove signal, use fallback decoding.
     81   }
     82 
     83   // Fallback to base-rate codec.
     84   codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop);
     85   assert(codec);
     86   uint64_t decoded_value = 0;
     87   if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
     88     return Diag(SPV_ERROR_INTERNAL)
     89            << "Failed to decode opcode_and_num_operands with global codec";
     90 
     91   if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) {
     92     // Received kMarkvNoneOfTheAbove signal, fallback further.
     93     return SPV_UNSUPPORTED;
     94   }
     95 
     96   *opcode = uint32_t(decoded_value & 0xFFFF);
     97   *num_operands = uint32_t(decoded_value >> 16);
     98   return SPV_SUCCESS;
     99 }
    100 
    101 spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf,
    102                                                 uint32_t fallback_method,
    103                                                 uint32_t* rank) {
    104   const auto* codec = GetMtfHuffmanCodec(mtf);
    105   if (!codec) {
    106     assert(fallback_method != kMtfNone);
    107     codec = GetMtfHuffmanCodec(fallback_method);
    108   }
    109 
    110   if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank";
    111 
    112   uint32_t decoded_value = 0;
    113   if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
    114     return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman";
    115 
    116   if (decoded_value == kMtfRankEncodedByValueSignal) {
    117     // Decode by value.
    118     if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length()))
    119       return Diag(SPV_ERROR_INTERNAL)
    120              << "Failed to decode MTF rank with varint";
    121     *rank += MarkvCodec::kMtfSmallestRankEncodedByValue;
    122   } else {
    123     // Decode using Huffman coding.
    124     assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue);
    125     *rank = decoded_value;
    126   }
    127   return SPV_SUCCESS;
    128 }
    129 
    130 spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) {
    131   auto* codec =
    132       model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_);
    133 
    134   uint64_t mtf = kMtfNone;
    135   if (codec) {
    136     uint64_t decoded_value = 0;
    137     if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value))
    138       return Diag(SPV_ERROR_INTERNAL)
    139              << "Failed to decode descriptor with Huffman";
    140 
    141     if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) {
    142       const uint32_t long_descriptor = uint32_t(decoded_value);
    143       mtf = GetMtfLongIdDescriptor(long_descriptor);
    144     }
    145   }
    146 
    147   if (mtf == kMtfNone) {
    148     if (model_->id_fallback_strategy() !=
    149         MarkvModel::IdFallbackStrategy::kShortDescriptor) {
    150       return SPV_UNSUPPORTED;
    151     }
    152 
    153     uint64_t decoded_value = 0;
    154     if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits))
    155       return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor";
    156     const uint32_t short_descriptor = uint32_t(decoded_value);
    157     if (short_descriptor == 0) {
    158       // Forward declared id.
    159       return SPV_UNSUPPORTED;
    160     }
    161     mtf = GetMtfShortIdDescriptor(short_descriptor);
    162   }
    163 
    164   return DecodeExistingId(mtf, id);
    165 }
    166 
    167 spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) {
    168   assert(multi_mtf_.GetSize(mtf) > 0);
    169   *id = 0;
    170 
    171   uint32_t rank = 0;
    172 
    173   if (multi_mtf_.GetSize(mtf) == 1) {
    174     rank = 1;
    175   } else {
    176     const spv_result_t result =
    177         DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank);
    178     if (result != SPV_SUCCESS) return result;
    179   }
    180 
    181   assert(rank);
    182   if (!multi_mtf_.ValueFromRank(mtf, rank, id))
    183     return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds";
    184 
    185   return SPV_SUCCESS;
    186 }
    187 
    188 spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) {
    189   {
    190     const spv_result_t result = DecodeIdWithDescriptor(id);
    191     if (result != SPV_UNSUPPORTED) return result;
    192   }
    193 
    194   const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction(
    195       SpvOp(inst_.opcode))(operand_index_);
    196   uint32_t rank = 0;
    197   *id = 0;
    198 
    199   if (model_->id_fallback_strategy() ==
    200       MarkvModel::IdFallbackStrategy::kRuleBased) {
    201     uint64_t mtf = GetRuleBasedMtf();
    202     if (mtf != kMtfNone && !can_forward_declare) {
    203       return DecodeExistingId(mtf, id);
    204     }
    205 
    206     if (mtf == kMtfNone) mtf = kMtfAll;
    207     {
    208       const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank);
    209       if (result != SPV_SUCCESS) return result;
    210     }
    211 
    212     if (rank == 0) {
    213       // This is the first occurrence of a forward declared id.
    214       *id = GetIdBound();
    215       SetIdBound(*id + 1);
    216       multi_mtf_.Insert(kMtfAll, *id);
    217       multi_mtf_.Insert(kMtfForwardDeclared, *id);
    218       if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id);
    219     } else {
    220       if (!multi_mtf_.ValueFromRank(mtf, rank, id))
    221         return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
    222     }
    223   } else {
    224     assert(can_forward_declare);
    225 
    226     if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
    227       return Diag(SPV_ERROR_INTERNAL)
    228              << "Failed to decode MTF rank with varint";
    229 
    230     if (rank == 0) {
    231       // This is the first occurrence of a forward declared id.
    232       *id = GetIdBound();
    233       SetIdBound(*id + 1);
    234       multi_mtf_.Insert(kMtfForwardDeclared, *id);
    235     } else {
    236       if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id))
    237         return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds";
    238     }
    239   }
    240   assert(*id);
    241   return SPV_SUCCESS;
    242 }
    243 
    244 spv_result_t MarkvDecoder::DecodeTypeId() {
    245   if (inst_.opcode == SpvOpFunctionParameter) {
    246     assert(!remaining_function_parameter_types_.empty());
    247     inst_.type_id = remaining_function_parameter_types_.front();
    248     remaining_function_parameter_types_.pop_front();
    249     return SPV_SUCCESS;
    250   }
    251 
    252   {
    253     const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id);
    254     if (result != SPV_UNSUPPORTED) return result;
    255   }
    256 
    257   assert(model_->id_fallback_strategy() ==
    258          MarkvModel::IdFallbackStrategy::kRuleBased);
    259 
    260   uint64_t mtf = GetRuleBasedMtf();
    261   assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))(
    262       operand_index_));
    263 
    264   if (mtf == kMtfNone) {
    265     mtf = kMtfTypeNonFunction;
    266     // Function types should have been handled by GetRuleBasedMtf.
    267     assert(inst_.opcode != SpvOpFunction);
    268   }
    269 
    270   return DecodeExistingId(mtf, &inst_.type_id);
    271 }
    272 
    273 spv_result_t MarkvDecoder::DecodeResultId() {
    274   uint32_t rank = 0;
    275 
    276   const uint64_t num_still_forward_declared =
    277       multi_mtf_.GetSize(kMtfForwardDeclared);
    278 
    279   if (num_still_forward_declared) {
    280     // Some ids were forward declared. Check if this id is one of them.
    281     uint64_t id_was_forward_declared;
    282     if (!reader_.ReadBits(&id_was_forward_declared, 1))
    283       return Diag(SPV_ERROR_INVALID_BINARY)
    284              << "Failed to read id_was_forward_declared flag";
    285 
    286     if (id_was_forward_declared) {
    287       if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length()))
    288         return Diag(SPV_ERROR_INVALID_BINARY)
    289                << "Failed to read MTF rank of forward declared id";
    290 
    291       if (rank) {
    292         // The id was forward declared, recover it from kMtfForwardDeclared.
    293         if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank,
    294                                       &inst_.result_id))
    295           return Diag(SPV_ERROR_INTERNAL)
    296                  << "Forward declared MTF rank is out of bounds";
    297 
    298         // We can now remove the id from kMtfForwardDeclared.
    299         if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id))
    300           return Diag(SPV_ERROR_INTERNAL)
    301                  << "Failed to remove id from kMtfForwardDeclared";
    302       }
    303     }
    304   }
    305 
    306   if (inst_.result_id == 0) {
    307     // The id was not forward declared, issue a new id.
    308     inst_.result_id = GetIdBound();
    309     SetIdBound(inst_.result_id + 1);
    310   }
    311 
    312   if (model_->id_fallback_strategy() ==
    313       MarkvModel::IdFallbackStrategy::kRuleBased) {
    314     if (!rank) {
    315       multi_mtf_.Insert(kMtfAll, inst_.result_id);
    316     }
    317   }
    318 
    319   return SPV_SUCCESS;
    320 }
    321 
    322 spv_result_t MarkvDecoder::DecodeLiteralNumber(
    323     const spv_parsed_operand_t& operand) {
    324   if (operand.number_bit_width <= 32) {
    325     uint32_t word = 0;
    326     const spv_result_t result = DecodeNonIdWord(&word);
    327     if (result != SPV_SUCCESS) return result;
    328     inst_words_.push_back(word);
    329   } else {
    330     assert(operand.number_bit_width <= 64);
    331     uint64_t word = 0;
    332     if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) {
    333       if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length()))
    334         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64";
    335     } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) {
    336       int64_t val = 0;
    337       if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(),
    338                                         model_->s64_block_exponent()))
    339         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64";
    340       std::memcpy(&word, &val, 8);
    341     } else if (operand.number_kind == SPV_NUMBER_FLOATING) {
    342       if (!reader_.ReadUnencoded(&word))
    343         return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64";
    344     } else {
    345       return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length";
    346     }
    347     inst_words_.push_back(static_cast<uint32_t>(word));
    348     inst_words_.push_back(static_cast<uint32_t>(word >> 32));
    349   }
    350   return SPV_SUCCESS;
    351 }
    352 
    353 bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) {
    354   const size_t num_bits_to_next_byte =
    355       GetNumBitsToNextByte(reader_.GetNumReadBits());
    356   if (num_bits_to_next_byte == 0 ||
    357       num_bits_to_next_byte > byte_break_if_less_than)
    358     return true;
    359 
    360   uint64_t bits = 0;
    361   if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false;
    362 
    363   assert(bits == 0);
    364   if (bits != 0) return false;
    365 
    366   return true;
    367 }
    368 
    369 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) {
    370   const bool header_read_success =
    371       reader_.ReadUnencoded(&header_.magic_number) &&
    372       reader_.ReadUnencoded(&header_.markv_version) &&
    373       reader_.ReadUnencoded(&header_.markv_model) &&
    374       reader_.ReadUnencoded(&header_.markv_length_in_bits) &&
    375       reader_.ReadUnencoded(&header_.spirv_version) &&
    376       reader_.ReadUnencoded(&header_.spirv_generator);
    377 
    378   if (!header_read_success)
    379     return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header";
    380 
    381   if (header_.markv_length_in_bits == 0)
    382     return Diag(SPV_ERROR_INVALID_BINARY)
    383            << "Header markv_length_in_bits field is zero";
    384 
    385   if (header_.magic_number != MarkvCodec::kMarkvMagicNumber)
    386     return Diag(SPV_ERROR_INVALID_BINARY)
    387            << "MARK-V binary has incorrect magic number";
    388 
    389   // TODO(atgoo (at) github.com): Print version strings.
    390   if (header_.markv_version != MarkvCodec::GetMarkvVersion())
    391     return Diag(SPV_ERROR_INVALID_BINARY)
    392            << "MARK-V binary and the codec have different versions";
    393 
    394   const uint32_t model_type = header_.markv_model >> 16;
    395   const uint32_t model_version = header_.markv_model & 0xFFFF;
    396   if (model_type != model_->model_type())
    397     return Diag(SPV_ERROR_INVALID_BINARY)
    398            << "MARK-V binary and the codec use different MARK-V models";
    399 
    400   if (model_version != model_->model_version())
    401     return Diag(SPV_ERROR_INVALID_BINARY)
    402            << "MARK-V binary and the codec use different versions if the same "
    403            << "MARK-V model";
    404 
    405   spirv_.reserve(header_.markv_length_in_bits / 2);  // Heuristic.
    406   spirv_.resize(5, 0);
    407   spirv_[0] = SpvMagicNumber;
    408   spirv_[1] = header_.spirv_version;
    409   spirv_[2] = header_.spirv_generator;
    410 
    411   if (logger_) {
    412     reader_.SetCallback(
    413         [this](const std::string& str) { logger_->AppendBitSequence(str); });
    414   }
    415 
    416   while (reader_.GetNumReadBits() < header_.markv_length_in_bits) {
    417     inst_ = {};
    418     const spv_result_t decode_result = DecodeInstruction();
    419     if (decode_result != SPV_SUCCESS) return decode_result;
    420   }
    421 
    422   if (validator_options_) {
    423     spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()};
    424     const spv_result_t result = spvValidateWithOptions(
    425         context_, validator_options_, &validation_binary, nullptr);
    426     if (result != SPV_SUCCESS) return result;
    427   }
    428 
    429   // Validate the decode binary
    430   if (reader_.GetNumReadBits() != header_.markv_length_in_bits ||
    431       !reader_.OnlyZeroesLeft()) {
    432     return Diag(SPV_ERROR_INVALID_BINARY)
    433            << "MARK-V binary has wrong stated bit length "
    434            << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits;
    435   }
    436 
    437   // Decoding of the module is finished, validation state should have correct
    438   // id bound.
    439   spirv_[3] = GetIdBound();
    440 
    441   *spirv_binary = std::move(spirv_);
    442   return SPV_SUCCESS;
    443 }
    444 
    445 // TODO(atgoo (at) github.com): The implementation borrows heavily from
    446 // Parser::parseOperand.
    447 // Consider coupling them together in some way once MARK-V codec is more mature.
    448 // For now it's better to keep the code independent for experimentation
    449 // purposes.
    450 spv_result_t MarkvDecoder::DecodeOperand(
    451     size_t operand_offset, const spv_operand_type_t type,
    452     spv_operand_pattern_t* expected_operands) {
    453   const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
    454 
    455   memset(&operand_, 0, sizeof(operand_));
    456 
    457   assert((operand_offset >> 16) == 0);
    458   operand_.offset = static_cast<uint16_t>(operand_offset);
    459   operand_.type = type;
    460 
    461   // Set default values, may be updated later.
    462   operand_.number_kind = SPV_NUMBER_NONE;
    463   operand_.number_bit_width = 0;
    464 
    465   const size_t first_word_index = inst_words_.size();
    466 
    467   switch (type) {
    468     case SPV_OPERAND_TYPE_RESULT_ID: {
    469       const spv_result_t result = DecodeResultId();
    470       if (result != SPV_SUCCESS) return result;
    471 
    472       inst_words_.push_back(inst_.result_id);
    473       SetIdBound(std::max(GetIdBound(), inst_.result_id + 1));
    474       PromoteIfNeeded(inst_.result_id);
    475       break;
    476     }
    477 
    478     case SPV_OPERAND_TYPE_TYPE_ID: {
    479       const spv_result_t result = DecodeTypeId();
    480       if (result != SPV_SUCCESS) return result;
    481 
    482       inst_words_.push_back(inst_.type_id);
    483       SetIdBound(std::max(GetIdBound(), inst_.type_id + 1));
    484       PromoteIfNeeded(inst_.type_id);
    485       break;
    486     }
    487 
    488     case SPV_OPERAND_TYPE_ID:
    489     case SPV_OPERAND_TYPE_OPTIONAL_ID:
    490     case SPV_OPERAND_TYPE_SCOPE_ID:
    491     case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: {
    492       uint32_t id = 0;
    493       const spv_result_t result = DecodeRefId(&id);
    494       if (result != SPV_SUCCESS) return result;
    495 
    496       if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0";
    497 
    498       if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) {
    499         operand_.type = SPV_OPERAND_TYPE_ID;
    500 
    501         if (opcode == SpvOpExtInst && operand_.offset == 3) {
    502           // The current word is the extended instruction set id.
    503           // Set the extended instruction set type for the current
    504           // instruction.
    505           auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id);
    506           if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) {
    507             return Diag(SPV_ERROR_INVALID_ID)
    508                    << "OpExtInst set id " << id
    509                    << " does not reference an OpExtInstImport result Id";
    510           }
    511           inst_.ext_inst_type = ext_inst_type_iter->second;
    512         }
    513       }
    514 
    515       inst_words_.push_back(id);
    516       SetIdBound(std::max(GetIdBound(), id + 1));
    517       PromoteIfNeeded(id);
    518       break;
    519     }
    520 
    521     case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: {
    522       uint32_t word = 0;
    523       const spv_result_t result = DecodeNonIdWord(&word);
    524       if (result != SPV_SUCCESS) return result;
    525 
    526       inst_words_.push_back(word);
    527 
    528       assert(SpvOpExtInst == opcode);
    529       assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE);
    530       spv_ext_inst_desc ext_inst;
    531       if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst))
    532         return Diag(SPV_ERROR_INVALID_BINARY)
    533                << "Invalid extended instruction number: " << word;
    534       spvPushOperandTypes(ext_inst->operandTypes, expected_operands);
    535       break;
    536     }
    537 
    538     case SPV_OPERAND_TYPE_LITERAL_INTEGER:
    539     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: {
    540       // These are regular single-word literal integer operands.
    541       // Post-parsing validation should check the range of the parsed value.
    542       operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER;
    543       // It turns out they are always unsigned integers!
    544       operand_.number_kind = SPV_NUMBER_UNSIGNED_INT;
    545       operand_.number_bit_width = 32;
    546 
    547       uint32_t word = 0;
    548       const spv_result_t result = DecodeNonIdWord(&word);
    549       if (result != SPV_SUCCESS) return result;
    550 
    551       inst_words_.push_back(word);
    552       break;
    553     }
    554 
    555     case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER:
    556     case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: {
    557       operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER;
    558       if (opcode == SpvOpSwitch) {
    559         // The literal operands have the same type as the value
    560         // referenced by the selector Id.
    561         const uint32_t selector_id = inst_words_.at(1);
    562         const auto type_id_iter = id_to_type_id_.find(selector_id);
    563         if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) {
    564           return Diag(SPV_ERROR_INVALID_BINARY)
    565                  << "Invalid OpSwitch: selector id " << selector_id
    566                  << " has no type";
    567         }
    568         uint32_t type_id = type_id_iter->second;
    569 
    570         if (selector_id == type_id) {
    571           // Recall that by convention, a result ID that is a type definition
    572           // maps to itself.
    573           return Diag(SPV_ERROR_INVALID_BINARY)
    574                  << "Invalid OpSwitch: selector id " << selector_id
    575                  << " is a type, not a value";
    576         }
    577         if (auto error = SetNumericTypeInfoForType(&operand_, type_id))
    578           return error;
    579         if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT &&
    580             operand_.number_kind != SPV_NUMBER_SIGNED_INT) {
    581           return Diag(SPV_ERROR_INVALID_BINARY)
    582                  << "Invalid OpSwitch: selector id " << selector_id
    583                  << " is not a scalar integer";
    584         }
    585       } else {
    586         assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant);
    587         // The literal number type is determined by the type Id for the
    588         // constant.
    589         assert(inst_.type_id);
    590         if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id))
    591           return error;
    592       }
    593 
    594       if (auto error = DecodeLiteralNumber(operand_)) return error;
    595 
    596       break;
    597     }
    598 
    599     case SPV_OPERAND_TYPE_LITERAL_STRING:
    600     case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: {
    601       operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING;
    602       std::vector<char> str;
    603       auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode);
    604 
    605       if (codec) {
    606         std::string decoded_string;
    607         const bool huffman_result =
    608             codec->DecodeFromStream(GetReadBitCallback(), &decoded_string);
    609         assert(huffman_result);
    610         if (!huffman_result)
    611           return Diag(SPV_ERROR_INVALID_BINARY)
    612                  << "Failed to read literal string";
    613 
    614         if (decoded_string != "kMarkvNoneOfTheAbove") {
    615           std::copy(decoded_string.begin(), decoded_string.end(),
    616                     std::back_inserter(str));
    617           str.push_back('\0');
    618         }
    619       }
    620 
    621       // The loop is expected to terminate once we encounter '\0' or exhaust
    622       // the bit stream.
    623       if (str.empty()) {
    624         while (true) {
    625           char ch = 0;
    626           if (!reader_.ReadUnencoded(&ch))
    627             return Diag(SPV_ERROR_INVALID_BINARY)
    628                    << "Failed to read literal string";
    629 
    630           str.push_back(ch);
    631 
    632           if (ch == '\0') break;
    633         }
    634       }
    635 
    636       while (str.size() % 4 != 0) str.push_back('\0');
    637 
    638       inst_words_.resize(inst_words_.size() + str.size() / 4);
    639       std::memcpy(&inst_words_[first_word_index], str.data(), str.size());
    640 
    641       if (SpvOpExtInstImport == opcode) {
    642         // Record the extended instruction type for the ID for this import.
    643         // There is only one string literal argument to OpExtInstImport,
    644         // so it's sufficient to guard this just on the opcode.
    645         const spv_ext_inst_type_t ext_inst_type =
    646             spvExtInstImportTypeGet(str.data());
    647         if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) {
    648           return Diag(SPV_ERROR_INVALID_BINARY)
    649                  << "Invalid extended instruction import '" << str.data()
    650                  << "'";
    651         }
    652         // We must have parsed a valid result ID.  It's a condition
    653         // of the grammar, and we only accept non-zero result Ids.
    654         assert(inst_.result_id);
    655         const bool inserted =
    656             import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type)
    657                 .second;
    658         (void)inserted;
    659         assert(inserted);
    660       }
    661       break;
    662     }
    663 
    664     case SPV_OPERAND_TYPE_CAPABILITY:
    665     case SPV_OPERAND_TYPE_SOURCE_LANGUAGE:
    666     case SPV_OPERAND_TYPE_EXECUTION_MODEL:
    667     case SPV_OPERAND_TYPE_ADDRESSING_MODEL:
    668     case SPV_OPERAND_TYPE_MEMORY_MODEL:
    669     case SPV_OPERAND_TYPE_EXECUTION_MODE:
    670     case SPV_OPERAND_TYPE_STORAGE_CLASS:
    671     case SPV_OPERAND_TYPE_DIMENSIONALITY:
    672     case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE:
    673     case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE:
    674     case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT:
    675     case SPV_OPERAND_TYPE_FP_ROUNDING_MODE:
    676     case SPV_OPERAND_TYPE_LINKAGE_TYPE:
    677     case SPV_OPERAND_TYPE_ACCESS_QUALIFIER:
    678     case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER:
    679     case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE:
    680     case SPV_OPERAND_TYPE_DECORATION:
    681     case SPV_OPERAND_TYPE_BUILT_IN:
    682     case SPV_OPERAND_TYPE_GROUP_OPERATION:
    683     case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS:
    684     case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: {
    685       // A single word that is a plain enum value.
    686       uint32_t word = 0;
    687       const spv_result_t result = DecodeNonIdWord(&word);
    688       if (result != SPV_SUCCESS) return result;
    689 
    690       inst_words_.push_back(word);
    691 
    692       // Map an optional operand type to its corresponding concrete type.
    693       if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER)
    694         operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER;
    695 
    696       spv_operand_desc entry;
    697       if (grammar_.lookupOperand(type, word, &entry)) {
    698         return Diag(SPV_ERROR_INVALID_BINARY)
    699                << "Invalid " << spvOperandTypeStr(operand_.type)
    700                << " operand: " << word;
    701       }
    702 
    703       // Prepare to accept operands to this operand, if needed.
    704       spvPushOperandTypes(entry->operandTypes, expected_operands);
    705       break;
    706     }
    707 
    708     case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE:
    709     case SPV_OPERAND_TYPE_FUNCTION_CONTROL:
    710     case SPV_OPERAND_TYPE_LOOP_CONTROL:
    711     case SPV_OPERAND_TYPE_IMAGE:
    712     case SPV_OPERAND_TYPE_OPTIONAL_IMAGE:
    713     case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS:
    714     case SPV_OPERAND_TYPE_SELECTION_CONTROL: {
    715       // This operand is a mask.
    716       uint32_t word = 0;
    717       const spv_result_t result = DecodeNonIdWord(&word);
    718       if (result != SPV_SUCCESS) return result;
    719 
    720       inst_words_.push_back(word);
    721 
    722       // Map an optional operand type to its corresponding concrete type.
    723       if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE)
    724         operand_.type = SPV_OPERAND_TYPE_IMAGE;
    725       else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS)
    726         operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS;
    727 
    728       // Check validity of set mask bits. Also prepare for operands for those
    729       // masks if they have any.  To get operand order correct, scan from
    730       // MSB to LSB since we can only prepend operands to a pattern.
    731       // The only case in the grammar where you have more than one mask bit
    732       // having an operand is for image operands.  See SPIR-V 3.14 Image
    733       // Operands.
    734       uint32_t remaining_word = word;
    735       for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) {
    736         if (remaining_word & mask) {
    737           spv_operand_desc entry;
    738           if (grammar_.lookupOperand(type, mask, &entry)) {
    739             return Diag(SPV_ERROR_INVALID_BINARY)
    740                    << "Invalid " << spvOperandTypeStr(operand_.type)
    741                    << " operand: " << word << " has invalid mask component "
    742                    << mask;
    743           }
    744           remaining_word ^= mask;
    745           spvPushOperandTypes(entry->operandTypes, expected_operands);
    746         }
    747       }
    748       if (word == 0) {
    749         // An all-zeroes mask *might* also be valid.
    750         spv_operand_desc entry;
    751         if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) {
    752           // Prepare for its operands, if any.
    753           spvPushOperandTypes(entry->operandTypes, expected_operands);
    754         }
    755       }
    756       break;
    757     }
    758     default:
    759       return Diag(SPV_ERROR_INVALID_BINARY)
    760              << "Internal error: Unhandled operand type: " << type;
    761   }
    762 
    763   operand_.num_words = uint16_t(inst_words_.size() - first_word_index);
    764 
    765   assert(spvOperandIsConcrete(operand_.type));
    766 
    767   parsed_operands_.push_back(operand_);
    768 
    769   return SPV_SUCCESS;
    770 }
    771 
    772 spv_result_t MarkvDecoder::DecodeInstruction() {
    773   parsed_operands_.clear();
    774   inst_words_.clear();
    775 
    776   // Opcode/num_words placeholder, the word will be filled in later.
    777   inst_words_.push_back(0);
    778 
    779   bool num_operands_still_unknown = true;
    780   {
    781     uint32_t opcode = 0;
    782     uint32_t num_operands = 0;
    783 
    784     const spv_result_t opcode_decoding_result =
    785         DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands);
    786     if (opcode_decoding_result < 0) return opcode_decoding_result;
    787 
    788     if (opcode_decoding_result == SPV_SUCCESS) {
    789       inst_.num_operands = static_cast<uint16_t>(num_operands);
    790       num_operands_still_unknown = false;
    791     } else {
    792       if (!reader_.ReadVariableWidthU32(&opcode,
    793                                         model_->opcode_chunk_length())) {
    794         return Diag(SPV_ERROR_INVALID_BINARY)
    795                << "Failed to read opcode of instruction";
    796       }
    797     }
    798 
    799     inst_.opcode = static_cast<uint16_t>(opcode);
    800   }
    801 
    802   const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
    803 
    804   spv_opcode_desc opcode_desc;
    805   if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) {
    806     return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode";
    807   }
    808 
    809   spv_operand_pattern_t expected_operands;
    810   expected_operands.reserve(opcode_desc->numTypes);
    811   for (auto i = 0; i < opcode_desc->numTypes; i++) {
    812     expected_operands.push_back(
    813         opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]);
    814   }
    815 
    816   if (num_operands_still_unknown) {
    817     if (!OpcodeHasFixedNumberOfOperands(opcode)) {
    818       if (!reader_.ReadVariableWidthU16(&inst_.num_operands,
    819                                         model_->num_operands_chunk_length()))
    820         return Diag(SPV_ERROR_INVALID_BINARY)
    821                << "Failed to read num_operands of instruction";
    822     } else {
    823       inst_.num_operands = static_cast<uint16_t>(expected_operands.size());
    824     }
    825   }
    826 
    827   for (operand_index_ = 0;
    828        operand_index_ < static_cast<size_t>(inst_.num_operands);
    829        ++operand_index_) {
    830     assert(!expected_operands.empty());
    831     const spv_operand_type_t type =
    832         spvTakeFirstMatchableOperand(&expected_operands);
    833 
    834     const size_t operand_offset = inst_words_.size();
    835 
    836     const spv_result_t decode_result =
    837         DecodeOperand(operand_offset, type, &expected_operands);
    838 
    839     if (decode_result != SPV_SUCCESS) return decode_result;
    840   }
    841 
    842   assert(inst_.num_operands == parsed_operands_.size());
    843 
    844   // Only valid while inst_words_ and parsed_operands_ remain unchanged (until
    845   // next DecodeInstruction call).
    846   inst_.words = inst_words_.data();
    847   inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data();
    848   inst_.num_words = static_cast<uint16_t>(inst_words_.size());
    849   inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode));
    850 
    851   std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_));
    852 
    853   assert(inst_.num_words ==
    854              std::accumulate(
    855                  parsed_operands_.begin(), parsed_operands_.end(), 1,
    856                  [](int num_words, const spv_parsed_operand_t& operand) {
    857                    return num_words += operand.num_words;
    858                  }) &&
    859          "num_words in instruction doesn't correspond to the sum of num_words"
    860          "in the operands");
    861 
    862   RecordNumberType();
    863   ProcessCurInstruction();
    864 
    865   if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte))
    866     return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break";
    867 
    868   if (logger_) {
    869     logger_->NewLine();
    870     std::stringstream ss;
    871     ss << spvOpcodeString(opcode) << " ";
    872     for (size_t index = 1; index < inst_words_.size(); ++index)
    873       ss << inst_words_[index] << " ";
    874     logger_->AppendText(ss.str());
    875     logger_->NewLine();
    876     logger_->NewLine();
    877     if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION;
    878   }
    879 
    880   return SPV_SUCCESS;
    881 }
    882 
    883 spv_result_t MarkvDecoder::SetNumericTypeInfoForType(
    884     spv_parsed_operand_t* parsed_operand, uint32_t type_id) {
    885   assert(type_id != 0);
    886   auto type_info_iter = type_id_to_number_type_info_.find(type_id);
    887   if (type_info_iter == type_id_to_number_type_info_.end()) {
    888     return Diag(SPV_ERROR_INVALID_BINARY)
    889            << "Type Id " << type_id << " is not a type";
    890   }
    891 
    892   const NumberType& info = type_info_iter->second;
    893   if (info.type == SPV_NUMBER_NONE) {
    894     // This is a valid type, but for something other than a scalar number.
    895     return Diag(SPV_ERROR_INVALID_BINARY)
    896            << "Type Id " << type_id << " is not a scalar numeric type";
    897   }
    898 
    899   parsed_operand->number_kind = info.type;
    900   parsed_operand->number_bit_width = info.bit_width;
    901   // Round up the word count.
    902   parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32);
    903   return SPV_SUCCESS;
    904 }
    905 
    906 void MarkvDecoder::RecordNumberType() {
    907   const SpvOp opcode = static_cast<SpvOp>(inst_.opcode);
    908   if (spvOpcodeGeneratesType(opcode)) {
    909     NumberType info = {SPV_NUMBER_NONE, 0};
    910     if (SpvOpTypeInt == opcode) {
    911       info.bit_width = inst_.words[inst_.operands[1].offset];
    912       info.type = inst_.words[inst_.operands[2].offset]
    913                       ? SPV_NUMBER_SIGNED_INT
    914                       : SPV_NUMBER_UNSIGNED_INT;
    915     } else if (SpvOpTypeFloat == opcode) {
    916       info.bit_width = inst_.words[inst_.operands[1].offset];
    917       info.type = SPV_NUMBER_FLOATING;
    918     }
    919     // The *result* Id of a type generating instruction is the type Id.
    920     type_id_to_number_type_info_[inst_.result_id] = info;
    921   }
    922 }
    923 
    924 }  // namespace comp
    925 }  // namespace spvtools
    926