1 // Copyright (c) 2018 Google LLC 2 // 3 // Licensed under the Apache License, Version 2.0 (the "License"); 4 // you may not use this file except in compliance with the License. 5 // You may obtain a copy of the License at 6 // 7 // http://www.apache.org/licenses/LICENSE-2.0 8 // 9 // Unless required by applicable law or agreed to in writing, software 10 // distributed under the License is distributed on an "AS IS" BASIS, 11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 // See the License for the specific language governing permissions and 13 // limitations under the License. 14 15 #include "source/comp/markv_decoder.h" 16 17 #include <cstring> 18 #include <iterator> 19 #include <numeric> 20 21 #include "source/ext_inst.h" 22 #include "source/opcode.h" 23 #include "spirv-tools/libspirv.hpp" 24 25 namespace spvtools { 26 namespace comp { 27 28 spv_result_t MarkvDecoder::DecodeNonIdWord(uint32_t* word) { 29 auto* codec = model_->GetNonIdWordHuffmanCodec(inst_.opcode, operand_index_); 30 31 if (codec) { 32 uint64_t decoded_value = 0; 33 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) 34 return Diag(SPV_ERROR_INVALID_BINARY) 35 << "Failed to decode non-id word with Huffman"; 36 37 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { 38 // The word decoded successfully. 39 *word = uint32_t(decoded_value); 40 assert(*word == decoded_value); 41 return SPV_SUCCESS; 42 } 43 44 // Received kMarkvNoneOfTheAbove signal, use fallback decoding. 45 } 46 47 const size_t chunk_length = 48 model_->GetOperandVariableWidthChunkLength(operand_.type); 49 if (chunk_length) { 50 if (!reader_.ReadVariableWidthU32(word, chunk_length)) 51 return Diag(SPV_ERROR_INVALID_BINARY) 52 << "Failed to decode non-id word with varint"; 53 } else { 54 if (!reader_.ReadUnencoded(word)) 55 return Diag(SPV_ERROR_INVALID_BINARY) 56 << "Failed to read unencoded non-id word"; 57 } 58 return SPV_SUCCESS; 59 } 60 61 spv_result_t MarkvDecoder::DecodeOpcodeAndNumberOfOperands( 62 uint32_t* opcode, uint32_t* num_operands) { 63 // First try to use the Markov chain codec. 64 auto* codec = 65 model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(GetPrevOpcode()); 66 if (codec) { 67 uint64_t decoded_value = 0; 68 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) 69 return Diag(SPV_ERROR_INTERNAL) 70 << "Failed to decode opcode_and_num_operands, previous opcode is " 71 << spvOpcodeString(GetPrevOpcode()); 72 73 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { 74 // The word was successfully decoded. 75 *opcode = uint32_t(decoded_value & 0xFFFF); 76 *num_operands = uint32_t(decoded_value >> 16); 77 return SPV_SUCCESS; 78 } 79 80 // Received kMarkvNoneOfTheAbove signal, use fallback decoding. 81 } 82 83 // Fallback to base-rate codec. 84 codec = model_->GetOpcodeAndNumOperandsMarkovHuffmanCodec(SpvOpNop); 85 assert(codec); 86 uint64_t decoded_value = 0; 87 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) 88 return Diag(SPV_ERROR_INTERNAL) 89 << "Failed to decode opcode_and_num_operands with global codec"; 90 91 if (decoded_value == MarkvModel::GetMarkvNoneOfTheAbove()) { 92 // Received kMarkvNoneOfTheAbove signal, fallback further. 93 return SPV_UNSUPPORTED; 94 } 95 96 *opcode = uint32_t(decoded_value & 0xFFFF); 97 *num_operands = uint32_t(decoded_value >> 16); 98 return SPV_SUCCESS; 99 } 100 101 spv_result_t MarkvDecoder::DecodeMtfRankHuffman(uint64_t mtf, 102 uint32_t fallback_method, 103 uint32_t* rank) { 104 const auto* codec = GetMtfHuffmanCodec(mtf); 105 if (!codec) { 106 assert(fallback_method != kMtfNone); 107 codec = GetMtfHuffmanCodec(fallback_method); 108 } 109 110 if (!codec) return Diag(SPV_ERROR_INTERNAL) << "No codec to decode MTF rank"; 111 112 uint32_t decoded_value = 0; 113 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) 114 return Diag(SPV_ERROR_INTERNAL) << "Failed to decode MTF rank with Huffman"; 115 116 if (decoded_value == kMtfRankEncodedByValueSignal) { 117 // Decode by value. 118 if (!reader_.ReadVariableWidthU32(rank, model_->mtf_rank_chunk_length())) 119 return Diag(SPV_ERROR_INTERNAL) 120 << "Failed to decode MTF rank with varint"; 121 *rank += MarkvCodec::kMtfSmallestRankEncodedByValue; 122 } else { 123 // Decode using Huffman coding. 124 assert(decoded_value < MarkvCodec::kMtfSmallestRankEncodedByValue); 125 *rank = decoded_value; 126 } 127 return SPV_SUCCESS; 128 } 129 130 spv_result_t MarkvDecoder::DecodeIdWithDescriptor(uint32_t* id) { 131 auto* codec = 132 model_->GetIdDescriptorHuffmanCodec(inst_.opcode, operand_index_); 133 134 uint64_t mtf = kMtfNone; 135 if (codec) { 136 uint64_t decoded_value = 0; 137 if (!codec->DecodeFromStream(GetReadBitCallback(), &decoded_value)) 138 return Diag(SPV_ERROR_INTERNAL) 139 << "Failed to decode descriptor with Huffman"; 140 141 if (decoded_value != MarkvModel::GetMarkvNoneOfTheAbove()) { 142 const uint32_t long_descriptor = uint32_t(decoded_value); 143 mtf = GetMtfLongIdDescriptor(long_descriptor); 144 } 145 } 146 147 if (mtf == kMtfNone) { 148 if (model_->id_fallback_strategy() != 149 MarkvModel::IdFallbackStrategy::kShortDescriptor) { 150 return SPV_UNSUPPORTED; 151 } 152 153 uint64_t decoded_value = 0; 154 if (!reader_.ReadBits(&decoded_value, MarkvCodec::kShortDescriptorNumBits)) 155 return Diag(SPV_ERROR_INTERNAL) << "Failed to read short descriptor"; 156 const uint32_t short_descriptor = uint32_t(decoded_value); 157 if (short_descriptor == 0) { 158 // Forward declared id. 159 return SPV_UNSUPPORTED; 160 } 161 mtf = GetMtfShortIdDescriptor(short_descriptor); 162 } 163 164 return DecodeExistingId(mtf, id); 165 } 166 167 spv_result_t MarkvDecoder::DecodeExistingId(uint64_t mtf, uint32_t* id) { 168 assert(multi_mtf_.GetSize(mtf) > 0); 169 *id = 0; 170 171 uint32_t rank = 0; 172 173 if (multi_mtf_.GetSize(mtf) == 1) { 174 rank = 1; 175 } else { 176 const spv_result_t result = 177 DecodeMtfRankHuffman(mtf, kMtfGenericNonZeroRank, &rank); 178 if (result != SPV_SUCCESS) return result; 179 } 180 181 assert(rank); 182 if (!multi_mtf_.ValueFromRank(mtf, rank, id)) 183 return Diag(SPV_ERROR_INTERNAL) << "MTF rank is out of bounds"; 184 185 return SPV_SUCCESS; 186 } 187 188 spv_result_t MarkvDecoder::DecodeRefId(uint32_t* id) { 189 { 190 const spv_result_t result = DecodeIdWithDescriptor(id); 191 if (result != SPV_UNSUPPORTED) return result; 192 } 193 194 const bool can_forward_declare = spvOperandCanBeForwardDeclaredFunction( 195 SpvOp(inst_.opcode))(operand_index_); 196 uint32_t rank = 0; 197 *id = 0; 198 199 if (model_->id_fallback_strategy() == 200 MarkvModel::IdFallbackStrategy::kRuleBased) { 201 uint64_t mtf = GetRuleBasedMtf(); 202 if (mtf != kMtfNone && !can_forward_declare) { 203 return DecodeExistingId(mtf, id); 204 } 205 206 if (mtf == kMtfNone) mtf = kMtfAll; 207 { 208 const spv_result_t result = DecodeMtfRankHuffman(mtf, kMtfAll, &rank); 209 if (result != SPV_SUCCESS) return result; 210 } 211 212 if (rank == 0) { 213 // This is the first occurrence of a forward declared id. 214 *id = GetIdBound(); 215 SetIdBound(*id + 1); 216 multi_mtf_.Insert(kMtfAll, *id); 217 multi_mtf_.Insert(kMtfForwardDeclared, *id); 218 if (mtf != kMtfAll) multi_mtf_.Insert(mtf, *id); 219 } else { 220 if (!multi_mtf_.ValueFromRank(mtf, rank, id)) 221 return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; 222 } 223 } else { 224 assert(can_forward_declare); 225 226 if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) 227 return Diag(SPV_ERROR_INTERNAL) 228 << "Failed to decode MTF rank with varint"; 229 230 if (rank == 0) { 231 // This is the first occurrence of a forward declared id. 232 *id = GetIdBound(); 233 SetIdBound(*id + 1); 234 multi_mtf_.Insert(kMtfForwardDeclared, *id); 235 } else { 236 if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, id)) 237 return Diag(SPV_ERROR_INTERNAL) << "MTF rank out of bounds"; 238 } 239 } 240 assert(*id); 241 return SPV_SUCCESS; 242 } 243 244 spv_result_t MarkvDecoder::DecodeTypeId() { 245 if (inst_.opcode == SpvOpFunctionParameter) { 246 assert(!remaining_function_parameter_types_.empty()); 247 inst_.type_id = remaining_function_parameter_types_.front(); 248 remaining_function_parameter_types_.pop_front(); 249 return SPV_SUCCESS; 250 } 251 252 { 253 const spv_result_t result = DecodeIdWithDescriptor(&inst_.type_id); 254 if (result != SPV_UNSUPPORTED) return result; 255 } 256 257 assert(model_->id_fallback_strategy() == 258 MarkvModel::IdFallbackStrategy::kRuleBased); 259 260 uint64_t mtf = GetRuleBasedMtf(); 261 assert(!spvOperandCanBeForwardDeclaredFunction(SpvOp(inst_.opcode))( 262 operand_index_)); 263 264 if (mtf == kMtfNone) { 265 mtf = kMtfTypeNonFunction; 266 // Function types should have been handled by GetRuleBasedMtf. 267 assert(inst_.opcode != SpvOpFunction); 268 } 269 270 return DecodeExistingId(mtf, &inst_.type_id); 271 } 272 273 spv_result_t MarkvDecoder::DecodeResultId() { 274 uint32_t rank = 0; 275 276 const uint64_t num_still_forward_declared = 277 multi_mtf_.GetSize(kMtfForwardDeclared); 278 279 if (num_still_forward_declared) { 280 // Some ids were forward declared. Check if this id is one of them. 281 uint64_t id_was_forward_declared; 282 if (!reader_.ReadBits(&id_was_forward_declared, 1)) 283 return Diag(SPV_ERROR_INVALID_BINARY) 284 << "Failed to read id_was_forward_declared flag"; 285 286 if (id_was_forward_declared) { 287 if (!reader_.ReadVariableWidthU32(&rank, model_->mtf_rank_chunk_length())) 288 return Diag(SPV_ERROR_INVALID_BINARY) 289 << "Failed to read MTF rank of forward declared id"; 290 291 if (rank) { 292 // The id was forward declared, recover it from kMtfForwardDeclared. 293 if (!multi_mtf_.ValueFromRank(kMtfForwardDeclared, rank, 294 &inst_.result_id)) 295 return Diag(SPV_ERROR_INTERNAL) 296 << "Forward declared MTF rank is out of bounds"; 297 298 // We can now remove the id from kMtfForwardDeclared. 299 if (!multi_mtf_.Remove(kMtfForwardDeclared, inst_.result_id)) 300 return Diag(SPV_ERROR_INTERNAL) 301 << "Failed to remove id from kMtfForwardDeclared"; 302 } 303 } 304 } 305 306 if (inst_.result_id == 0) { 307 // The id was not forward declared, issue a new id. 308 inst_.result_id = GetIdBound(); 309 SetIdBound(inst_.result_id + 1); 310 } 311 312 if (model_->id_fallback_strategy() == 313 MarkvModel::IdFallbackStrategy::kRuleBased) { 314 if (!rank) { 315 multi_mtf_.Insert(kMtfAll, inst_.result_id); 316 } 317 } 318 319 return SPV_SUCCESS; 320 } 321 322 spv_result_t MarkvDecoder::DecodeLiteralNumber( 323 const spv_parsed_operand_t& operand) { 324 if (operand.number_bit_width <= 32) { 325 uint32_t word = 0; 326 const spv_result_t result = DecodeNonIdWord(&word); 327 if (result != SPV_SUCCESS) return result; 328 inst_words_.push_back(word); 329 } else { 330 assert(operand.number_bit_width <= 64); 331 uint64_t word = 0; 332 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 333 if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) 334 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal U64"; 335 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 336 int64_t val = 0; 337 if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), 338 model_->s64_block_exponent())) 339 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal S64"; 340 std::memcpy(&word, &val, 8); 341 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 342 if (!reader_.ReadUnencoded(&word)) 343 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read literal F64"; 344 } else { 345 return Diag(SPV_ERROR_INTERNAL) << "Unsupported bit length"; 346 } 347 inst_words_.push_back(static_cast<uint32_t>(word)); 348 inst_words_.push_back(static_cast<uint32_t>(word >> 32)); 349 } 350 return SPV_SUCCESS; 351 } 352 353 bool MarkvDecoder::ReadToByteBreak(size_t byte_break_if_less_than) { 354 const size_t num_bits_to_next_byte = 355 GetNumBitsToNextByte(reader_.GetNumReadBits()); 356 if (num_bits_to_next_byte == 0 || 357 num_bits_to_next_byte > byte_break_if_less_than) 358 return true; 359 360 uint64_t bits = 0; 361 if (!reader_.ReadBits(&bits, num_bits_to_next_byte)) return false; 362 363 assert(bits == 0); 364 if (bits != 0) return false; 365 366 return true; 367 } 368 369 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) { 370 const bool header_read_success = 371 reader_.ReadUnencoded(&header_.magic_number) && 372 reader_.ReadUnencoded(&header_.markv_version) && 373 reader_.ReadUnencoded(&header_.markv_model) && 374 reader_.ReadUnencoded(&header_.markv_length_in_bits) && 375 reader_.ReadUnencoded(&header_.spirv_version) && 376 reader_.ReadUnencoded(&header_.spirv_generator); 377 378 if (!header_read_success) 379 return Diag(SPV_ERROR_INVALID_BINARY) << "Unable to read MARK-V header"; 380 381 if (header_.markv_length_in_bits == 0) 382 return Diag(SPV_ERROR_INVALID_BINARY) 383 << "Header markv_length_in_bits field is zero"; 384 385 if (header_.magic_number != MarkvCodec::kMarkvMagicNumber) 386 return Diag(SPV_ERROR_INVALID_BINARY) 387 << "MARK-V binary has incorrect magic number"; 388 389 // TODO(atgoo (at) github.com): Print version strings. 390 if (header_.markv_version != MarkvCodec::GetMarkvVersion()) 391 return Diag(SPV_ERROR_INVALID_BINARY) 392 << "MARK-V binary and the codec have different versions"; 393 394 const uint32_t model_type = header_.markv_model >> 16; 395 const uint32_t model_version = header_.markv_model & 0xFFFF; 396 if (model_type != model_->model_type()) 397 return Diag(SPV_ERROR_INVALID_BINARY) 398 << "MARK-V binary and the codec use different MARK-V models"; 399 400 if (model_version != model_->model_version()) 401 return Diag(SPV_ERROR_INVALID_BINARY) 402 << "MARK-V binary and the codec use different versions if the same " 403 << "MARK-V model"; 404 405 spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. 406 spirv_.resize(5, 0); 407 spirv_[0] = SpvMagicNumber; 408 spirv_[1] = header_.spirv_version; 409 spirv_[2] = header_.spirv_generator; 410 411 if (logger_) { 412 reader_.SetCallback( 413 [this](const std::string& str) { logger_->AppendBitSequence(str); }); 414 } 415 416 while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { 417 inst_ = {}; 418 const spv_result_t decode_result = DecodeInstruction(); 419 if (decode_result != SPV_SUCCESS) return decode_result; 420 } 421 422 if (validator_options_) { 423 spv_const_binary_t validation_binary = {spirv_.data(), spirv_.size()}; 424 const spv_result_t result = spvValidateWithOptions( 425 context_, validator_options_, &validation_binary, nullptr); 426 if (result != SPV_SUCCESS) return result; 427 } 428 429 // Validate the decode binary 430 if (reader_.GetNumReadBits() != header_.markv_length_in_bits || 431 !reader_.OnlyZeroesLeft()) { 432 return Diag(SPV_ERROR_INVALID_BINARY) 433 << "MARK-V binary has wrong stated bit length " 434 << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; 435 } 436 437 // Decoding of the module is finished, validation state should have correct 438 // id bound. 439 spirv_[3] = GetIdBound(); 440 441 *spirv_binary = std::move(spirv_); 442 return SPV_SUCCESS; 443 } 444 445 // TODO(atgoo (at) github.com): The implementation borrows heavily from 446 // Parser::parseOperand. 447 // Consider coupling them together in some way once MARK-V codec is more mature. 448 // For now it's better to keep the code independent for experimentation 449 // purposes. 450 spv_result_t MarkvDecoder::DecodeOperand( 451 size_t operand_offset, const spv_operand_type_t type, 452 spv_operand_pattern_t* expected_operands) { 453 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode); 454 455 memset(&operand_, 0, sizeof(operand_)); 456 457 assert((operand_offset >> 16) == 0); 458 operand_.offset = static_cast<uint16_t>(operand_offset); 459 operand_.type = type; 460 461 // Set default values, may be updated later. 462 operand_.number_kind = SPV_NUMBER_NONE; 463 operand_.number_bit_width = 0; 464 465 const size_t first_word_index = inst_words_.size(); 466 467 switch (type) { 468 case SPV_OPERAND_TYPE_RESULT_ID: { 469 const spv_result_t result = DecodeResultId(); 470 if (result != SPV_SUCCESS) return result; 471 472 inst_words_.push_back(inst_.result_id); 473 SetIdBound(std::max(GetIdBound(), inst_.result_id + 1)); 474 PromoteIfNeeded(inst_.result_id); 475 break; 476 } 477 478 case SPV_OPERAND_TYPE_TYPE_ID: { 479 const spv_result_t result = DecodeTypeId(); 480 if (result != SPV_SUCCESS) return result; 481 482 inst_words_.push_back(inst_.type_id); 483 SetIdBound(std::max(GetIdBound(), inst_.type_id + 1)); 484 PromoteIfNeeded(inst_.type_id); 485 break; 486 } 487 488 case SPV_OPERAND_TYPE_ID: 489 case SPV_OPERAND_TYPE_OPTIONAL_ID: 490 case SPV_OPERAND_TYPE_SCOPE_ID: 491 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { 492 uint32_t id = 0; 493 const spv_result_t result = DecodeRefId(&id); 494 if (result != SPV_SUCCESS) return result; 495 496 if (id == 0) return Diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; 497 498 if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { 499 operand_.type = SPV_OPERAND_TYPE_ID; 500 501 if (opcode == SpvOpExtInst && operand_.offset == 3) { 502 // The current word is the extended instruction set id. 503 // Set the extended instruction set type for the current 504 // instruction. 505 auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); 506 if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { 507 return Diag(SPV_ERROR_INVALID_ID) 508 << "OpExtInst set id " << id 509 << " does not reference an OpExtInstImport result Id"; 510 } 511 inst_.ext_inst_type = ext_inst_type_iter->second; 512 } 513 } 514 515 inst_words_.push_back(id); 516 SetIdBound(std::max(GetIdBound(), id + 1)); 517 PromoteIfNeeded(id); 518 break; 519 } 520 521 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { 522 uint32_t word = 0; 523 const spv_result_t result = DecodeNonIdWord(&word); 524 if (result != SPV_SUCCESS) return result; 525 526 inst_words_.push_back(word); 527 528 assert(SpvOpExtInst == opcode); 529 assert(inst_.ext_inst_type != SPV_EXT_INST_TYPE_NONE); 530 spv_ext_inst_desc ext_inst; 531 if (grammar_.lookupExtInst(inst_.ext_inst_type, word, &ext_inst)) 532 return Diag(SPV_ERROR_INVALID_BINARY) 533 << "Invalid extended instruction number: " << word; 534 spvPushOperandTypes(ext_inst->operandTypes, expected_operands); 535 break; 536 } 537 538 case SPV_OPERAND_TYPE_LITERAL_INTEGER: 539 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { 540 // These are regular single-word literal integer operands. 541 // Post-parsing validation should check the range of the parsed value. 542 operand_.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; 543 // It turns out they are always unsigned integers! 544 operand_.number_kind = SPV_NUMBER_UNSIGNED_INT; 545 operand_.number_bit_width = 32; 546 547 uint32_t word = 0; 548 const spv_result_t result = DecodeNonIdWord(&word); 549 if (result != SPV_SUCCESS) return result; 550 551 inst_words_.push_back(word); 552 break; 553 } 554 555 case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: 556 case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: { 557 operand_.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; 558 if (opcode == SpvOpSwitch) { 559 // The literal operands have the same type as the value 560 // referenced by the selector Id. 561 const uint32_t selector_id = inst_words_.at(1); 562 const auto type_id_iter = id_to_type_id_.find(selector_id); 563 if (type_id_iter == id_to_type_id_.end() || type_id_iter->second == 0) { 564 return Diag(SPV_ERROR_INVALID_BINARY) 565 << "Invalid OpSwitch: selector id " << selector_id 566 << " has no type"; 567 } 568 uint32_t type_id = type_id_iter->second; 569 570 if (selector_id == type_id) { 571 // Recall that by convention, a result ID that is a type definition 572 // maps to itself. 573 return Diag(SPV_ERROR_INVALID_BINARY) 574 << "Invalid OpSwitch: selector id " << selector_id 575 << " is a type, not a value"; 576 } 577 if (auto error = SetNumericTypeInfoForType(&operand_, type_id)) 578 return error; 579 if (operand_.number_kind != SPV_NUMBER_UNSIGNED_INT && 580 operand_.number_kind != SPV_NUMBER_SIGNED_INT) { 581 return Diag(SPV_ERROR_INVALID_BINARY) 582 << "Invalid OpSwitch: selector id " << selector_id 583 << " is not a scalar integer"; 584 } 585 } else { 586 assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); 587 // The literal number type is determined by the type Id for the 588 // constant. 589 assert(inst_.type_id); 590 if (auto error = SetNumericTypeInfoForType(&operand_, inst_.type_id)) 591 return error; 592 } 593 594 if (auto error = DecodeLiteralNumber(operand_)) return error; 595 596 break; 597 } 598 599 case SPV_OPERAND_TYPE_LITERAL_STRING: 600 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { 601 operand_.type = SPV_OPERAND_TYPE_LITERAL_STRING; 602 std::vector<char> str; 603 auto* codec = model_->GetLiteralStringHuffmanCodec(inst_.opcode); 604 605 if (codec) { 606 std::string decoded_string; 607 const bool huffman_result = 608 codec->DecodeFromStream(GetReadBitCallback(), &decoded_string); 609 assert(huffman_result); 610 if (!huffman_result) 611 return Diag(SPV_ERROR_INVALID_BINARY) 612 << "Failed to read literal string"; 613 614 if (decoded_string != "kMarkvNoneOfTheAbove") { 615 std::copy(decoded_string.begin(), decoded_string.end(), 616 std::back_inserter(str)); 617 str.push_back('\0'); 618 } 619 } 620 621 // The loop is expected to terminate once we encounter '\0' or exhaust 622 // the bit stream. 623 if (str.empty()) { 624 while (true) { 625 char ch = 0; 626 if (!reader_.ReadUnencoded(&ch)) 627 return Diag(SPV_ERROR_INVALID_BINARY) 628 << "Failed to read literal string"; 629 630 str.push_back(ch); 631 632 if (ch == '\0') break; 633 } 634 } 635 636 while (str.size() % 4 != 0) str.push_back('\0'); 637 638 inst_words_.resize(inst_words_.size() + str.size() / 4); 639 std::memcpy(&inst_words_[first_word_index], str.data(), str.size()); 640 641 if (SpvOpExtInstImport == opcode) { 642 // Record the extended instruction type for the ID for this import. 643 // There is only one string literal argument to OpExtInstImport, 644 // so it's sufficient to guard this just on the opcode. 645 const spv_ext_inst_type_t ext_inst_type = 646 spvExtInstImportTypeGet(str.data()); 647 if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { 648 return Diag(SPV_ERROR_INVALID_BINARY) 649 << "Invalid extended instruction import '" << str.data() 650 << "'"; 651 } 652 // We must have parsed a valid result ID. It's a condition 653 // of the grammar, and we only accept non-zero result Ids. 654 assert(inst_.result_id); 655 const bool inserted = 656 import_id_to_ext_inst_type_.emplace(inst_.result_id, ext_inst_type) 657 .second; 658 (void)inserted; 659 assert(inserted); 660 } 661 break; 662 } 663 664 case SPV_OPERAND_TYPE_CAPABILITY: 665 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: 666 case SPV_OPERAND_TYPE_EXECUTION_MODEL: 667 case SPV_OPERAND_TYPE_ADDRESSING_MODEL: 668 case SPV_OPERAND_TYPE_MEMORY_MODEL: 669 case SPV_OPERAND_TYPE_EXECUTION_MODE: 670 case SPV_OPERAND_TYPE_STORAGE_CLASS: 671 case SPV_OPERAND_TYPE_DIMENSIONALITY: 672 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: 673 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: 674 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: 675 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: 676 case SPV_OPERAND_TYPE_LINKAGE_TYPE: 677 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: 678 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: 679 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: 680 case SPV_OPERAND_TYPE_DECORATION: 681 case SPV_OPERAND_TYPE_BUILT_IN: 682 case SPV_OPERAND_TYPE_GROUP_OPERATION: 683 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: 684 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { 685 // A single word that is a plain enum value. 686 uint32_t word = 0; 687 const spv_result_t result = DecodeNonIdWord(&word); 688 if (result != SPV_SUCCESS) return result; 689 690 inst_words_.push_back(word); 691 692 // Map an optional operand type to its corresponding concrete type. 693 if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) 694 operand_.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; 695 696 spv_operand_desc entry; 697 if (grammar_.lookupOperand(type, word, &entry)) { 698 return Diag(SPV_ERROR_INVALID_BINARY) 699 << "Invalid " << spvOperandTypeStr(operand_.type) 700 << " operand: " << word; 701 } 702 703 // Prepare to accept operands to this operand, if needed. 704 spvPushOperandTypes(entry->operandTypes, expected_operands); 705 break; 706 } 707 708 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: 709 case SPV_OPERAND_TYPE_FUNCTION_CONTROL: 710 case SPV_OPERAND_TYPE_LOOP_CONTROL: 711 case SPV_OPERAND_TYPE_IMAGE: 712 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: 713 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: 714 case SPV_OPERAND_TYPE_SELECTION_CONTROL: { 715 // This operand is a mask. 716 uint32_t word = 0; 717 const spv_result_t result = DecodeNonIdWord(&word); 718 if (result != SPV_SUCCESS) return result; 719 720 inst_words_.push_back(word); 721 722 // Map an optional operand type to its corresponding concrete type. 723 if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) 724 operand_.type = SPV_OPERAND_TYPE_IMAGE; 725 else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) 726 operand_.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; 727 728 // Check validity of set mask bits. Also prepare for operands for those 729 // masks if they have any. To get operand order correct, scan from 730 // MSB to LSB since we can only prepend operands to a pattern. 731 // The only case in the grammar where you have more than one mask bit 732 // having an operand is for image operands. See SPIR-V 3.14 Image 733 // Operands. 734 uint32_t remaining_word = word; 735 for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { 736 if (remaining_word & mask) { 737 spv_operand_desc entry; 738 if (grammar_.lookupOperand(type, mask, &entry)) { 739 return Diag(SPV_ERROR_INVALID_BINARY) 740 << "Invalid " << spvOperandTypeStr(operand_.type) 741 << " operand: " << word << " has invalid mask component " 742 << mask; 743 } 744 remaining_word ^= mask; 745 spvPushOperandTypes(entry->operandTypes, expected_operands); 746 } 747 } 748 if (word == 0) { 749 // An all-zeroes mask *might* also be valid. 750 spv_operand_desc entry; 751 if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { 752 // Prepare for its operands, if any. 753 spvPushOperandTypes(entry->operandTypes, expected_operands); 754 } 755 } 756 break; 757 } 758 default: 759 return Diag(SPV_ERROR_INVALID_BINARY) 760 << "Internal error: Unhandled operand type: " << type; 761 } 762 763 operand_.num_words = uint16_t(inst_words_.size() - first_word_index); 764 765 assert(spvOperandIsConcrete(operand_.type)); 766 767 parsed_operands_.push_back(operand_); 768 769 return SPV_SUCCESS; 770 } 771 772 spv_result_t MarkvDecoder::DecodeInstruction() { 773 parsed_operands_.clear(); 774 inst_words_.clear(); 775 776 // Opcode/num_words placeholder, the word will be filled in later. 777 inst_words_.push_back(0); 778 779 bool num_operands_still_unknown = true; 780 { 781 uint32_t opcode = 0; 782 uint32_t num_operands = 0; 783 784 const spv_result_t opcode_decoding_result = 785 DecodeOpcodeAndNumberOfOperands(&opcode, &num_operands); 786 if (opcode_decoding_result < 0) return opcode_decoding_result; 787 788 if (opcode_decoding_result == SPV_SUCCESS) { 789 inst_.num_operands = static_cast<uint16_t>(num_operands); 790 num_operands_still_unknown = false; 791 } else { 792 if (!reader_.ReadVariableWidthU32(&opcode, 793 model_->opcode_chunk_length())) { 794 return Diag(SPV_ERROR_INVALID_BINARY) 795 << "Failed to read opcode of instruction"; 796 } 797 } 798 799 inst_.opcode = static_cast<uint16_t>(opcode); 800 } 801 802 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode); 803 804 spv_opcode_desc opcode_desc; 805 if (grammar_.lookupOpcode(opcode, &opcode_desc) != SPV_SUCCESS) { 806 return Diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; 807 } 808 809 spv_operand_pattern_t expected_operands; 810 expected_operands.reserve(opcode_desc->numTypes); 811 for (auto i = 0; i < opcode_desc->numTypes; i++) { 812 expected_operands.push_back( 813 opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); 814 } 815 816 if (num_operands_still_unknown) { 817 if (!OpcodeHasFixedNumberOfOperands(opcode)) { 818 if (!reader_.ReadVariableWidthU16(&inst_.num_operands, 819 model_->num_operands_chunk_length())) 820 return Diag(SPV_ERROR_INVALID_BINARY) 821 << "Failed to read num_operands of instruction"; 822 } else { 823 inst_.num_operands = static_cast<uint16_t>(expected_operands.size()); 824 } 825 } 826 827 for (operand_index_ = 0; 828 operand_index_ < static_cast<size_t>(inst_.num_operands); 829 ++operand_index_) { 830 assert(!expected_operands.empty()); 831 const spv_operand_type_t type = 832 spvTakeFirstMatchableOperand(&expected_operands); 833 834 const size_t operand_offset = inst_words_.size(); 835 836 const spv_result_t decode_result = 837 DecodeOperand(operand_offset, type, &expected_operands); 838 839 if (decode_result != SPV_SUCCESS) return decode_result; 840 } 841 842 assert(inst_.num_operands == parsed_operands_.size()); 843 844 // Only valid while inst_words_ and parsed_operands_ remain unchanged (until 845 // next DecodeInstruction call). 846 inst_.words = inst_words_.data(); 847 inst_.operands = parsed_operands_.empty() ? nullptr : parsed_operands_.data(); 848 inst_.num_words = static_cast<uint16_t>(inst_words_.size()); 849 inst_words_[0] = spvOpcodeMake(inst_.num_words, SpvOp(inst_.opcode)); 850 851 std::copy(inst_words_.begin(), inst_words_.end(), std::back_inserter(spirv_)); 852 853 assert(inst_.num_words == 854 std::accumulate( 855 parsed_operands_.begin(), parsed_operands_.end(), 1, 856 [](int num_words, const spv_parsed_operand_t& operand) { 857 return num_words += operand.num_words; 858 }) && 859 "num_words in instruction doesn't correspond to the sum of num_words" 860 "in the operands"); 861 862 RecordNumberType(); 863 ProcessCurInstruction(); 864 865 if (!ReadToByteBreak(MarkvCodec::kByteBreakAfterInstIfLessThanUntilNextByte)) 866 return Diag(SPV_ERROR_INVALID_BINARY) << "Failed to read to byte break"; 867 868 if (logger_) { 869 logger_->NewLine(); 870 std::stringstream ss; 871 ss << spvOpcodeString(opcode) << " "; 872 for (size_t index = 1; index < inst_words_.size(); ++index) 873 ss << inst_words_[index] << " "; 874 logger_->AppendText(ss.str()); 875 logger_->NewLine(); 876 logger_->NewLine(); 877 if (!logger_->DebugInstruction(inst_)) return SPV_REQUESTED_TERMINATION; 878 } 879 880 return SPV_SUCCESS; 881 } 882 883 spv_result_t MarkvDecoder::SetNumericTypeInfoForType( 884 spv_parsed_operand_t* parsed_operand, uint32_t type_id) { 885 assert(type_id != 0); 886 auto type_info_iter = type_id_to_number_type_info_.find(type_id); 887 if (type_info_iter == type_id_to_number_type_info_.end()) { 888 return Diag(SPV_ERROR_INVALID_BINARY) 889 << "Type Id " << type_id << " is not a type"; 890 } 891 892 const NumberType& info = type_info_iter->second; 893 if (info.type == SPV_NUMBER_NONE) { 894 // This is a valid type, but for something other than a scalar number. 895 return Diag(SPV_ERROR_INVALID_BINARY) 896 << "Type Id " << type_id << " is not a scalar numeric type"; 897 } 898 899 parsed_operand->number_kind = info.type; 900 parsed_operand->number_bit_width = info.bit_width; 901 // Round up the word count. 902 parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32); 903 return SPV_SUCCESS; 904 } 905 906 void MarkvDecoder::RecordNumberType() { 907 const SpvOp opcode = static_cast<SpvOp>(inst_.opcode); 908 if (spvOpcodeGeneratesType(opcode)) { 909 NumberType info = {SPV_NUMBER_NONE, 0}; 910 if (SpvOpTypeInt == opcode) { 911 info.bit_width = inst_.words[inst_.operands[1].offset]; 912 info.type = inst_.words[inst_.operands[2].offset] 913 ? SPV_NUMBER_SIGNED_INT 914 : SPV_NUMBER_UNSIGNED_INT; 915 } else if (SpvOpTypeFloat == opcode) { 916 info.bit_width = inst_.words[inst_.operands[1].offset]; 917 info.type = SPV_NUMBER_FLOATING; 918 } 919 // The *result* Id of a type generating instruction is the type Id. 920 type_id_to_number_type_info_[inst_.result_id] = info; 921 } 922 } 923 924 } // namespace comp 925 } // namespace spvtools 926