// Copyright (c) 2017 Google Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
//     http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

// Contains
// - SPIR-V to MARK-V encoder
// - MARK-V to SPIR-V decoder
//
// MARK-V is a compression format for SPIR-V binaries. It strips away
// non-essential information (such as result ids which can be regenerated) and
// uses various bit reduction techniques to reduce the size of the binary.
//
// MarkvModel is a flatbuffers object containing a set of rules defining how
// compression/decompression is done (coding schemes, dictionaries).
#include <algorithm>
#include <cassert>
#include <cstring>
#include <functional>
#include <iostream>
#include <list>
#include <memory>
#include <numeric>
#include <sstream>
#include <string>
#include <unordered_map>
#include <unordered_set>
#include <vector>

#include "binary.h"
#include "diagnostic.h"
#include "enum_string_mapping.h"
#include "extensions.h"
#include "ext_inst.h"
#include "instruction.h"
#include "opcode.h"
#include "operand.h"
#include "spirv-tools/libspirv.h"
#include "spirv-tools/markv.h"
#include "spirv_endian.h"
#include "spirv_validator_options.h"
#include "util/bit_stream.h"
#include "util/parse_number.h"
#include "validate.h"
#include "val/instruction.h"
#include "val/validation_state.h"

using libspirv::Instruction;
using libspirv::ValidationState_t;
using spvtools::ValidateInstructionAndUpdateValidationState;
using spvutils::BitReaderWord64;
using spvutils::BitWriterWord64;

struct spv_markv_encoder_options_t {
};

struct spv_markv_decoder_options_t {
};

namespace {

const uint32_t kSpirvMagicNumber = SpvMagicNumber;
const uint32_t kMarkvMagicNumber = 0x07230303;

enum {
  kMarkvFirstOpcode = 65536,
  kMarkvOpNextInstructionEncodesResultId = 65536,
};

const size_t kCommentNumWhitespaces = 2;

// TODO(atgoo (at) github.com): This is a placeholder for an autogenerated flatbuffer
// containing MARK-V model for a specific dataset.
class MarkvModel {
 public:
  // Instruction structure: opcode, operand count and move-to-front id index.
  size_t opcode_chunk_length() const { return kOpcodeChunkLength; }
  size_t num_operands_chunk_length() const { return kNumOperandsChunkLength; }
  size_t id_index_chunk_length() const { return kIdIndexChunkLength; }

  // 16-bit literal numbers.
  size_t u16_chunk_length() const { return kU16ChunkLength; }
  size_t s16_chunk_length() const { return kS16ChunkLength; }
  size_t s16_block_exponent() const { return kS16BlockExponent; }

  // 32-bit literal numbers.
  size_t u32_chunk_length() const { return kU32ChunkLength; }
  size_t s32_chunk_length() const { return kS32ChunkLength; }
  size_t s32_block_exponent() const { return kS32BlockExponent; }

  // 64-bit literal numbers.
  size_t u64_chunk_length() const { return kU64ChunkLength; }
  size_t s64_chunk_length() const { return kS64ChunkLength; }
  size_t s64_block_exponent() const { return kS64BlockExponent; }

 private:
  // Hard-coded model parameters. All of them are part of the encoding, so
  // changing any value changes the produced MARK-V bit stream.
  static constexpr size_t kOpcodeChunkLength = 7;
  static constexpr size_t kNumOperandsChunkLength = 3;
  static constexpr size_t kIdIndexChunkLength = 3;

  static constexpr size_t kU16ChunkLength = 4;
  static constexpr size_t kS16ChunkLength = 4;
  static constexpr size_t kS16BlockExponent = 6;

  static constexpr size_t kU32ChunkLength = 8;
  static constexpr size_t kS32ChunkLength = 8;
  static constexpr size_t kS32BlockExponent = 10;

  static constexpr size_t kU64ChunkLength = 8;
  static constexpr size_t kS64ChunkLength = 8;
  static constexpr size_t kS64BlockExponent = 10;
};

// Returns a pointer to the single default model instance. The instance is a
// function-local static, so every call returns the same address.
const MarkvModel* GetDefaultModel() {
  static const MarkvModel kDefaultModel{};
  return &kDefaultModel;
}

// Returns chunk length used for variable length encoding of spirv operand
// words. Returns zero if operand type corresponds to potentially multiple
// words or a word which is not expected to profit from variable width encoding.
// Chunk length is selected based on the size of expected value.
// Most of these values will later be encoded with probability-based coding,
// but variable width integer coding is a good quick solution.
// TODO(atgoo (at) github.com): Put this in MarkvModel flatbuffer.
112 size_t GetOperandVariableWidthChunkLength(spv_operand_type_t type) { 113 switch (type) { 114 case SPV_OPERAND_TYPE_TYPE_ID: 115 return 4; 116 case SPV_OPERAND_TYPE_RESULT_ID: 117 case SPV_OPERAND_TYPE_ID: 118 case SPV_OPERAND_TYPE_SCOPE_ID: 119 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: 120 return 8; 121 case SPV_OPERAND_TYPE_LITERAL_INTEGER: 122 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: 123 return 6; 124 case SPV_OPERAND_TYPE_CAPABILITY: 125 return 6; 126 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: 127 case SPV_OPERAND_TYPE_EXECUTION_MODEL: 128 return 3; 129 case SPV_OPERAND_TYPE_ADDRESSING_MODEL: 130 case SPV_OPERAND_TYPE_MEMORY_MODEL: 131 return 2; 132 case SPV_OPERAND_TYPE_EXECUTION_MODE: 133 return 6; 134 case SPV_OPERAND_TYPE_STORAGE_CLASS: 135 return 4; 136 case SPV_OPERAND_TYPE_DIMENSIONALITY: 137 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: 138 return 3; 139 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: 140 return 2; 141 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: 142 return 6; 143 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: 144 case SPV_OPERAND_TYPE_LINKAGE_TYPE: 145 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: 146 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: 147 return 2; 148 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: 149 return 3; 150 case SPV_OPERAND_TYPE_DECORATION: 151 case SPV_OPERAND_TYPE_BUILT_IN: 152 return 6; 153 case SPV_OPERAND_TYPE_GROUP_OPERATION: 154 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: 155 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: 156 return 2; 157 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: 158 case SPV_OPERAND_TYPE_FUNCTION_CONTROL: 159 case SPV_OPERAND_TYPE_LOOP_CONTROL: 160 case SPV_OPERAND_TYPE_IMAGE: 161 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: 162 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: 163 case SPV_OPERAND_TYPE_SELECTION_CONTROL: 164 return 4; 165 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: 166 return 6; 167 default: 168 return 0; 169 } 170 return 0; 171 } 172 173 // Returns true if the opcode has a fixed 
number of operands. May return a 174 // false negative. 175 bool OpcodeHasFixedNumberOfOperands(SpvOp opcode) { 176 switch (opcode) { 177 // TODO(atgoo (at) github.com) This is not a complete list. 178 case SpvOpNop: 179 case SpvOpName: 180 case SpvOpUndef: 181 case SpvOpSizeOf: 182 case SpvOpLine: 183 case SpvOpNoLine: 184 case SpvOpDecorationGroup: 185 case SpvOpExtension: 186 case SpvOpExtInstImport: 187 case SpvOpMemoryModel: 188 case SpvOpCapability: 189 case SpvOpTypeVoid: 190 case SpvOpTypeBool: 191 case SpvOpTypeInt: 192 case SpvOpTypeFloat: 193 case SpvOpTypeVector: 194 case SpvOpTypeMatrix: 195 case SpvOpTypeSampler: 196 case SpvOpTypeSampledImage: 197 case SpvOpTypeArray: 198 case SpvOpTypePointer: 199 case SpvOpConstantTrue: 200 case SpvOpConstantFalse: 201 case SpvOpLabel: 202 case SpvOpBranch: 203 case SpvOpFunction: 204 case SpvOpFunctionParameter: 205 case SpvOpFunctionEnd: 206 case SpvOpBitcast: 207 case SpvOpCopyObject: 208 case SpvOpTranspose: 209 case SpvOpSNegate: 210 case SpvOpFNegate: 211 case SpvOpIAdd: 212 case SpvOpFAdd: 213 case SpvOpISub: 214 case SpvOpFSub: 215 case SpvOpIMul: 216 case SpvOpFMul: 217 case SpvOpUDiv: 218 case SpvOpSDiv: 219 case SpvOpFDiv: 220 case SpvOpUMod: 221 case SpvOpSRem: 222 case SpvOpSMod: 223 case SpvOpFRem: 224 case SpvOpFMod: 225 case SpvOpVectorTimesScalar: 226 case SpvOpMatrixTimesScalar: 227 case SpvOpVectorTimesMatrix: 228 case SpvOpMatrixTimesVector: 229 case SpvOpMatrixTimesMatrix: 230 case SpvOpOuterProduct: 231 case SpvOpDot: 232 return true; 233 default: 234 break; 235 } 236 return false; 237 } 238 239 size_t GetNumBitsToNextByte(size_t bit_pos) { 240 return (8 - (bit_pos % 8)) % 8; 241 } 242 243 bool ShouldByteBreak(size_t bit_pos) { 244 const size_t num_bits_to_next_byte = GetNumBitsToNextByte(bit_pos); 245 return num_bits_to_next_byte > 0; // && num_bits_to_next_byte <= 2; 246 } 247 248 // Defines and returns current MARK-V version. 
249 uint32_t GetMarkvVersion() { 250 const uint32_t kVersionMajor = 1; 251 const uint32_t kVersionMinor = 0; 252 return kVersionMinor | (kVersionMajor << 16); 253 } 254 255 class CommentLogger { 256 public: 257 void AppendText(const std::string& str) { 258 Append(str); 259 use_delimiter_ = false; 260 } 261 262 void AppendTextNewLine(const std::string& str) { 263 Append(str); 264 Append("\n"); 265 use_delimiter_ = false; 266 } 267 268 void AppendBitSequence(const std::string& str) { 269 if (use_delimiter_) 270 Append("-"); 271 Append(str); 272 use_delimiter_ = true; 273 } 274 275 void AppendWhitespaces(size_t num) { 276 Append(std::string(num, ' ')); 277 use_delimiter_ = false; 278 } 279 280 void NewLine() { 281 Append("\n"); 282 use_delimiter_ = false; 283 } 284 285 std::string GetText() const { 286 return ss_.str(); 287 } 288 289 private: 290 void Append(const std::string& str) { 291 ss_ << str; 292 // std::cerr << str; 293 } 294 295 std::stringstream ss_; 296 297 // If true a delimiter will be appended before the next bit sequence. 298 // Used to generate outputs like: 1100-0 1110-1-1100-1-1111-0 110-0. 299 bool use_delimiter_ = false; 300 }; 301 302 // Creates spv_text object containing text from |str|. 303 // The returned value is owned by the caller and needs to be destroyed with 304 // spvTextDestroy. 305 spv_text CreateSpvText(const std::string& str) { 306 spv_text out = new spv_text_t(); 307 assert(out); 308 char* cstr = new char[str.length() + 1]; 309 assert(cstr); 310 std::strncpy(cstr, str.c_str(), str.length()); 311 cstr[str.length()] = '\0'; 312 out->str = cstr; 313 out->length = str.length(); 314 return out; 315 } 316 317 // Base class for MARK-V encoder and decoder. Contains common functionality 318 // such as: 319 // - Validator connection and validation state. 320 // - SPIR-V grammar and helper functions. 
321 class MarkvCodecBase { 322 public: 323 virtual ~MarkvCodecBase() { 324 spvValidatorOptionsDestroy(validator_options_); 325 } 326 327 MarkvCodecBase() = delete; 328 329 void SetModel(const MarkvModel* model) { 330 model_ = model; 331 } 332 333 protected: 334 struct MarkvHeader { 335 MarkvHeader() { 336 magic_number = kMarkvMagicNumber; 337 markv_version = GetMarkvVersion(); 338 markv_model = 0; 339 markv_length_in_bits = 0; 340 spirv_version = 0; 341 spirv_generator = 0; 342 } 343 344 uint32_t magic_number; 345 uint32_t markv_version; 346 // Magic number to identify or verify MarkvModel used for encoding. 347 uint32_t markv_model; 348 uint32_t markv_length_in_bits; 349 uint32_t spirv_version; 350 uint32_t spirv_generator; 351 }; 352 353 explicit MarkvCodecBase(spv_const_context context, 354 spv_validator_options validator_options) 355 : validator_options_(validator_options), 356 vstate_(context, validator_options_), grammar_(context), 357 model_(GetDefaultModel()) {} 358 359 // Validates a single instruction and updates validation state of the module. 360 spv_result_t UpdateValidationState(const spv_parsed_instruction_t& inst) { 361 return ValidateInstructionAndUpdateValidationState(&vstate_, &inst); 362 } 363 364 // Returns the current instruction (the one last processed by the validator). 365 const Instruction& GetCurrentInstruction() const { 366 return vstate_.ordered_instructions().back(); 367 } 368 369 spv_validator_options validator_options_; 370 ValidationState_t vstate_; 371 const libspirv::AssemblyGrammar grammar_; 372 MarkvHeader header_; 373 const MarkvModel* model_; 374 375 // Move-to-front list of all ids. 376 // TODO(atgoo (at) github.com) Consider a better move-to-front implementation. 377 std::list<uint32_t> move_to_front_ids_; 378 }; 379 380 // SPIR-V to MARK-V encoder. Exposes functions EncodeHeader and 381 // EncodeInstruction which can be used as callback by spvBinaryParse. 
382 // Encoded binary is written to an internally maintained bitstream. 383 // After the last instruction is encoded, the resulting MARK-V binary can be 384 // acquired by calling GetMarkvBinary(). 385 // The encoder uses SPIR-V validator to keep internal state, therefore 386 // SPIR-V binary needs to be able to pass validator checks. 387 // CreateCommentsLogger() can be used to enable the encoder to write comments 388 // on how encoding was done, which can later be accessed with GetComments(). 389 class MarkvEncoder : public MarkvCodecBase { 390 public: 391 MarkvEncoder(spv_const_context context, 392 spv_const_markv_encoder_options options) 393 : MarkvCodecBase(context, GetValidatorOptions(options)), 394 options_(options) { 395 (void) options_; 396 } 397 398 // Writes data from SPIR-V header to MARK-V header. 399 spv_result_t EncodeHeader( 400 spv_endianness_t /* endian */, uint32_t /* magic */, 401 uint32_t version, uint32_t generator, uint32_t id_bound, 402 uint32_t /* schema */) { 403 vstate_.setIdBound(id_bound); 404 header_.spirv_version = version; 405 header_.spirv_generator = generator; 406 return SPV_SUCCESS; 407 } 408 409 // Encodes SPIR-V instruction to MARK-V and writes to bit stream. 410 // Operation can fail if the instruction fails to pass the validator or if 411 // the encoder stubmles on something unexpected. 412 spv_result_t EncodeInstruction(const spv_parsed_instruction_t& inst); 413 414 // Concatenates MARK-V header and the bit stream with encoded instructions 415 // into a single buffer and returns it as spv_markv_binary. The returned 416 // value is owned by the caller and needs to be destroyed with 417 // spvMarkvBinaryDestroy(). 
418 spv_markv_binary GetMarkvBinary() { 419 header_.markv_length_in_bits = 420 static_cast<uint32_t>(sizeof(header_) * 8 + writer_.GetNumBits()); 421 const size_t num_bytes = sizeof(header_) + writer_.GetDataSizeBytes(); 422 423 spv_markv_binary markv_binary = new spv_markv_binary_t(); 424 markv_binary->data = new uint8_t[num_bytes]; 425 markv_binary->length = num_bytes; 426 assert(writer_.GetData()); 427 std::memcpy(markv_binary->data, &header_, sizeof(header_)); 428 std::memcpy(markv_binary->data + sizeof(header_), 429 writer_.GetData(), writer_.GetDataSizeBytes()); 430 return markv_binary; 431 } 432 433 // Creates an internal logger which writes comments on the encoding process. 434 // Output can later be accessed with GetComments(). 435 void CreateCommentsLogger() { 436 logger_.reset(new CommentLogger()); 437 writer_.SetCallback([this](const std::string& str){ 438 logger_->AppendBitSequence(str); 439 }); 440 } 441 442 // Optionally adds disassembly to the comments. 443 // Disassembly should contain all instructions in the module separated by 444 // \n, and no header. 445 void SetDisassembly(std::string&& disassembly) { 446 disassembly_.reset(new std::stringstream(std::move(disassembly))); 447 } 448 449 // Extracts the next instruction line from the disassembly and logs it. 450 void LogDisassemblyInstruction() { 451 if (logger_ && disassembly_) { 452 std::string line; 453 std::getline(*disassembly_, line, '\n'); 454 logger_->AppendTextNewLine(line); 455 } 456 } 457 458 // Extracts the text from the comment logger. 459 std::string GetComments() const { 460 if (!logger_) 461 return ""; 462 return logger_->GetText(); 463 } 464 465 private: 466 // Creates and returns validator options. Return value owned by the caller. 467 static spv_validator_options GetValidatorOptions( 468 spv_const_markv_encoder_options) { 469 return spvValidatorOptionsCreate(); 470 } 471 472 // Writes a single word to bit stream. |type| determines if the word is 473 // encoded and how. 
474 void EncodeOperandWord(spv_operand_type_t type, uint32_t word) { 475 const size_t chunk_length = 476 GetOperandVariableWidthChunkLength(type); 477 if (chunk_length) { 478 writer_.WriteVariableWidthU32(word, chunk_length); 479 } else { 480 writer_.WriteUnencoded(word); 481 } 482 } 483 484 // Returns id index and updates move-to-front. 485 // Index is uint16 as SPIR-V module is guaranteed to have no more than 65535 486 // instructions. 487 uint16_t GetIdIndex(uint32_t id) { 488 if (all_known_ids_.count(id)) { 489 uint16_t index = 0; 490 for (auto it = move_to_front_ids_.begin(); 491 it != move_to_front_ids_.end(); ++it) { 492 if (*it == id) { 493 if (index != 0) { 494 move_to_front_ids_.erase(it); 495 move_to_front_ids_.push_front(id); 496 } 497 return index; 498 } 499 ++index; 500 } 501 assert(0 && "Id not found in move_to_front_ids_"); 502 return 0; 503 } else { 504 all_known_ids_.insert(id); 505 move_to_front_ids_.push_front(id); 506 return static_cast<uint16_t>(move_to_front_ids_.size() - 1); 507 } 508 } 509 510 void AddByteBreakIfAgreed() { 511 if (!ShouldByteBreak(writer_.GetNumBits())) 512 return; 513 514 if (logger_) { 515 logger_->AppendWhitespaces(kCommentNumWhitespaces); 516 logger_->AppendText("ByteBreak:"); 517 } 518 519 writer_.WriteBits(0, GetNumBitsToNextByte(writer_.GetNumBits())); 520 } 521 522 // Encodes a literal number operand and writes it to the bit stream. 523 void EncodeLiteralNumber(const Instruction& instruction, 524 const spv_parsed_operand_t& operand); 525 526 spv_const_markv_encoder_options options_; 527 528 // Bit stream where encoded instructions are written. 529 BitWriterWord64 writer_; 530 531 // If not nullptr, encoder will write comments. 532 std::unique_ptr<CommentLogger> logger_; 533 534 // If not nullptr, disassembled instruction lines will be written to comments. 535 // Format: \n separated instruction lines, no header. 
536 std::unique_ptr<std::stringstream> disassembly_; 537 538 // All ids which were previosly encountered in the module. 539 std::unordered_set<uint32_t> all_known_ids_; 540 }; 541 542 // Decodes MARK-V buffers written by MarkvEncoder. 543 class MarkvDecoder : public MarkvCodecBase { 544 public: 545 MarkvDecoder(spv_const_context context, 546 const uint8_t* markv_data, 547 size_t markv_size_bytes, 548 spv_const_markv_decoder_options options) 549 : MarkvCodecBase(context, GetValidatorOptions(options)), 550 options_(options), reader_(markv_data, markv_size_bytes) { 551 (void) options_; 552 vstate_.setIdBound(1); 553 parsed_operands_.reserve(25); 554 } 555 556 // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|. 557 // Can be called only once. Fails if data of wrong format or ends prematurely, 558 // of if validation fails. 559 spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary); 560 561 private: 562 // Describes the format of a typed literal number. 563 struct NumberType { 564 spv_number_kind_t type; 565 uint32_t bit_width; 566 }; 567 568 // Creates and returns validator options. Return value owned by the caller. 569 static spv_validator_options GetValidatorOptions( 570 spv_const_markv_decoder_options) { 571 return spvValidatorOptionsCreate(); 572 } 573 574 // Reads a single word from bit stream. |type| determines if the word needs 575 // to be decoded and how. Returns false if read fails. 576 bool DecodeOperandWord(spv_operand_type_t type, uint32_t* word) { 577 const size_t chunk_length = GetOperandVariableWidthChunkLength(type); 578 if (chunk_length) { 579 return reader_.ReadVariableWidthU32(word, chunk_length); 580 } else { 581 return reader_.ReadUnencoded(word); 582 } 583 } 584 585 // Fetches the id from the move-to-front list and moves it to front. 586 uint32_t GetIdAndMoveToFront(uint16_t index) { 587 if (index >= move_to_front_ids_.size()) { 588 // Issue new id. 
589 const uint32_t id = vstate_.getIdBound(); 590 move_to_front_ids_.push_front(id); 591 vstate_.setIdBound(id + 1); 592 return id; 593 } else { 594 if (index == 0) 595 return move_to_front_ids_.front(); 596 597 // Iterate to index. 598 auto it = move_to_front_ids_.begin(); 599 for (size_t i = 0; i < index; ++i) 600 ++it; 601 const uint32_t id = *it; 602 move_to_front_ids_.erase(it); 603 move_to_front_ids_.push_front(id); 604 return id; 605 } 606 } 607 608 // Decodes id index and fetches the id from move-to-front list. 609 bool DecodeId(uint32_t* id) { 610 uint16_t index = 0; 611 if (!reader_.ReadVariableWidthU16(&index, model_->id_index_chunk_length())) 612 return false; 613 614 *id = GetIdAndMoveToFront(index); 615 return true; 616 } 617 618 bool ReadToByteBreakIfAgreed() { 619 if (!ShouldByteBreak(reader_.GetNumReadBits())) 620 return true; 621 622 uint64_t bits = 0; 623 if (!reader_.ReadBits(&bits, 624 GetNumBitsToNextByte(reader_.GetNumReadBits()))) 625 return false; 626 627 if (bits != 0) 628 return false; 629 630 return true; 631 } 632 633 // Reads a literal number as it is described in |operand| from the bit stream, 634 // decodes and writes it to spirv_. 635 spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand); 636 637 // Reads instruction from bit stream, decodes and validates it. 638 // Decoded instruction is valid until the next call of DecodeInstruction(). 639 spv_result_t DecodeInstruction(spv_parsed_instruction_t* inst); 640 641 // Read operand from the stream decodes and validates it. 642 spv_result_t DecodeOperand(size_t instruction_offset, size_t operand_offset, 643 spv_parsed_instruction_t* inst, 644 const spv_operand_type_t type, 645 spv_operand_pattern_t* expected_operands, 646 bool read_result_id); 647 648 // Records the numeric type for an operand according to the type information 649 // associated with the given non-zero type Id. 
This can fail if the type Id 650 // is not a type Id, or if the type Id does not reference a scalar numeric 651 // type. On success, return SPV_SUCCESS and populates the num_words, 652 // number_kind, and number_bit_width fields of parsed_operand. 653 spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand, 654 uint32_t type_id); 655 656 // Records the number type for the given instruction, if that 657 // instruction generates a type. For types that aren't scalar numbers, 658 // record something with number kind SPV_NUMBER_NONE. 659 void RecordNumberType(const spv_parsed_instruction_t& inst); 660 661 spv_const_markv_decoder_options options_; 662 663 // Temporary sink where decoded SPIR-V words are written. Once it contains the 664 // entire module, the container is moved and returned. 665 std::vector<uint32_t> spirv_; 666 667 // Bit stream containing encoded data. 668 BitReaderWord64 reader_; 669 670 // Temporary storage for operands of the currently parsed instruction. 671 // Valid until next DecodeInstruction call. 672 std::vector<spv_parsed_operand_t> parsed_operands_; 673 674 // Maps a result ID to its type ID. By convention: 675 // - a result ID that is a type definition maps to itself. 676 // - a result ID without a type maps to 0. (E.g. for OpLabel) 677 std::unordered_map<uint32_t, uint32_t> id_to_type_id_; 678 // Maps a type ID to its number type description. 679 std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_; 680 // Maps an ExtInstImport id to the extended instruction type. 
681 std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_; 682 }; 683 684 void MarkvEncoder::EncodeLiteralNumber(const Instruction& instruction, 685 const spv_parsed_operand_t& operand) { 686 if (operand.number_bit_width == 32) { 687 const uint32_t word = instruction.word(operand.offset); 688 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 689 writer_.WriteVariableWidthU32(word, model_->u32_chunk_length()); 690 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 691 int32_t val = 0; 692 std::memcpy(&val, &word, 4); 693 writer_.WriteVariableWidthS32(val, model_->s32_chunk_length(), 694 model_->s32_block_exponent()); 695 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 696 writer_.WriteUnencoded(word); 697 } else { 698 assert(0); 699 } 700 } else if (operand.number_bit_width == 16) { 701 const uint16_t word = 702 static_cast<uint16_t>(instruction.word(operand.offset)); 703 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 704 writer_.WriteVariableWidthU16(word, model_->u16_chunk_length()); 705 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 706 int16_t val = 0; 707 std::memcpy(&val, &word, 2); 708 writer_.WriteVariableWidthS16(val, model_->s16_chunk_length(), 709 model_->s16_block_exponent()); 710 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 711 // TODO(atgoo (at) github.com) Write only 16 bits. 
712 writer_.WriteUnencoded(word); 713 } else { 714 assert(0); 715 } 716 } else { 717 assert(operand.number_bit_width == 64); 718 const uint64_t word = 719 uint64_t(instruction.word(operand.offset)) | 720 (uint64_t(instruction.word(operand.offset + 1)) << 32); 721 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 722 writer_.WriteVariableWidthU64(word, model_->u64_chunk_length()); 723 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 724 int64_t val = 0; 725 std::memcpy(&val, &word, 8); 726 writer_.WriteVariableWidthS64(val, model_->s64_chunk_length(), 727 model_->s64_block_exponent()); 728 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 729 writer_.WriteUnencoded(word); 730 } else { 731 assert(0); 732 } 733 } 734 } 735 736 spv_result_t MarkvEncoder::EncodeInstruction( 737 const spv_parsed_instruction_t& inst) { 738 const spv_result_t validation_result = UpdateValidationState(inst); 739 if (validation_result != SPV_SUCCESS) 740 return validation_result; 741 742 bool result_id_was_forward_declared = false; 743 if (all_known_ids_.count(inst.result_id)) { 744 // Result id of the instruction was forward declared. 745 // Write a service opcode to signal this to the decoder. 746 writer_.WriteVariableWidthU32(kMarkvOpNextInstructionEncodesResultId, 747 model_->opcode_chunk_length()); 748 result_id_was_forward_declared = true; 749 } 750 751 const Instruction& instruction = GetCurrentInstruction(); 752 const auto& operands = instruction.operands(); 753 754 LogDisassemblyInstruction(); 755 756 // Write opcode. 757 writer_.WriteVariableWidthU32(inst.opcode, model_->opcode_chunk_length()); 758 759 if (!OpcodeHasFixedNumberOfOperands(SpvOp(inst.opcode))) { 760 // If the opcode has a variable number of operands, encode the number of 761 // operands with the instruction. 
762 763 if (logger_) 764 logger_->AppendWhitespaces(kCommentNumWhitespaces); 765 766 writer_.WriteVariableWidthU16(inst.num_operands, 767 model_->num_operands_chunk_length()); 768 } 769 770 // Write operands. 771 for (const auto& operand : operands) { 772 if (operand.type == SPV_OPERAND_TYPE_RESULT_ID && 773 !result_id_was_forward_declared) { 774 // Register the id, but don't encode it. 775 GetIdIndex(instruction.word(operand.offset)); 776 continue; 777 } 778 779 if (logger_) 780 logger_->AppendWhitespaces(kCommentNumWhitespaces); 781 782 if (operand.type == SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER) { 783 EncodeLiteralNumber(instruction, operand); 784 } else if (operand.type == SPV_OPERAND_TYPE_LITERAL_STRING) { 785 const char* src = 786 reinterpret_cast<const char*>(&instruction.words()[operand.offset]); 787 const size_t length = spv_strnlen_s(src, operand.num_words * 4); 788 if (length == operand.num_words * 4) 789 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 790 << "Failed to find terminal character of literal string"; 791 for (size_t i = 0; i < length + 1; ++i) 792 writer_.WriteUnencoded(src[i]); 793 } else if (spvIsIdType(operand.type)) { 794 const uint16_t id_index = GetIdIndex(instruction.word(operand.offset)); 795 writer_.WriteVariableWidthU16(id_index, model_->id_index_chunk_length()); 796 } else { 797 for (int i = 0; i < operand.num_words; ++i) { 798 const uint32_t word = instruction.word(operand.offset + i); 799 EncodeOperandWord(operand.type, word); 800 } 801 } 802 } 803 804 AddByteBreakIfAgreed(); 805 806 if (logger_) { 807 logger_->NewLine(); 808 logger_->NewLine(); 809 } 810 811 return SPV_SUCCESS; 812 } 813 814 spv_result_t MarkvDecoder::DecodeLiteralNumber( 815 const spv_parsed_operand_t& operand) { 816 if (operand.number_bit_width == 32) { 817 uint32_t word = 0; 818 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 819 if (!reader_.ReadVariableWidthU32(&word, model_->u32_chunk_length())) 820 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 821 << 
"Failed to read literal U32"; 822 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 823 int32_t val = 0; 824 if (!reader_.ReadVariableWidthS32(&val, model_->s32_chunk_length(), 825 model_->s32_block_exponent())) 826 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 827 << "Failed to read literal S32"; 828 std::memcpy(&word, &val, 4); 829 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 830 if (!reader_.ReadUnencoded(&word)) 831 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 832 << "Failed to read literal F32"; 833 } else { 834 assert(0); 835 } 836 spirv_.push_back(word); 837 } else if (operand.number_bit_width == 16) { 838 uint32_t word = 0; 839 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 840 uint16_t val = 0; 841 if (!reader_.ReadVariableWidthU16(&val, model_->u16_chunk_length())) 842 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 843 << "Failed to read literal U16"; 844 word = val; 845 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 846 int16_t val = 0; 847 if (!reader_.ReadVariableWidthS16(&val, model_->s16_chunk_length(), 848 model_->s16_block_exponent())) 849 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 850 << "Failed to read literal S16"; 851 // Int16 is stored as int32 in SPIR-V, not as bits. 
852 int32_t val32 = val; 853 std::memcpy(&word, &val32, 4); 854 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 855 uint16_t word16 = 0; 856 if (!reader_.ReadUnencoded(&word16)) 857 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 858 << "Failed to read literal F16"; 859 word = word16; 860 } else { 861 assert(0); 862 } 863 spirv_.push_back(word); 864 } else { 865 assert(operand.number_bit_width == 64); 866 uint64_t word = 0; 867 if (operand.number_kind == SPV_NUMBER_UNSIGNED_INT) { 868 if (!reader_.ReadVariableWidthU64(&word, model_->u64_chunk_length())) 869 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 870 << "Failed to read literal U64"; 871 } else if (operand.number_kind == SPV_NUMBER_SIGNED_INT) { 872 int64_t val = 0; 873 if (!reader_.ReadVariableWidthS64(&val, model_->s64_chunk_length(), 874 model_->s64_block_exponent())) 875 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 876 << "Failed to read literal S64"; 877 std::memcpy(&word, &val, 8); 878 } else if (operand.number_kind == SPV_NUMBER_FLOATING) { 879 if (!reader_.ReadUnencoded(&word)) 880 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 881 << "Failed to read literal F64"; 882 } else { 883 assert(0); 884 } 885 spirv_.push_back(static_cast<uint32_t>(word)); 886 spirv_.push_back(static_cast<uint32_t>(word >> 32)); 887 } 888 return SPV_SUCCESS; 889 } 890 891 spv_result_t MarkvDecoder::DecodeModule(std::vector<uint32_t>* spirv_binary) { 892 const bool header_read_success = 893 reader_.ReadUnencoded(&header_.magic_number) && 894 reader_.ReadUnencoded(&header_.markv_version) && 895 reader_.ReadUnencoded(&header_.markv_model) && 896 reader_.ReadUnencoded(&header_.markv_length_in_bits) && 897 reader_.ReadUnencoded(&header_.spirv_version) && 898 reader_.ReadUnencoded(&header_.spirv_generator); 899 900 if (!header_read_success) 901 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 902 << "Unable to read MARK-V header"; 903 904 assert(header_.magic_number == kMarkvMagicNumber); 905 assert(header_.markv_length_in_bits > 
0); 906 907 if (header_.magic_number != kMarkvMagicNumber) 908 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 909 << "MARK-V binary has incorrect magic number"; 910 911 // TODO(atgoo (at) github.com): Print version strings. 912 if (header_.markv_version != GetMarkvVersion()) 913 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 914 << "MARK-V binary and the codec have different versions"; 915 916 spirv_.reserve(header_.markv_length_in_bits / 2); // Heuristic. 917 spirv_.resize(5, 0); 918 spirv_[0] = kSpirvMagicNumber; 919 spirv_[1] = header_.spirv_version; 920 spirv_[2] = header_.spirv_generator; 921 922 while (reader_.GetNumReadBits() < header_.markv_length_in_bits) { 923 spv_parsed_instruction_t inst = {}; 924 const spv_result_t decode_result = DecodeInstruction(&inst); 925 if (decode_result != SPV_SUCCESS) 926 return decode_result; 927 928 const spv_result_t validation_result = UpdateValidationState(inst); 929 if (validation_result != SPV_SUCCESS) 930 return validation_result; 931 } 932 933 934 if (reader_.GetNumReadBits() != header_.markv_length_in_bits || 935 !reader_.OnlyZeroesLeft()) { 936 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 937 << "MARK-V binary has wrong stated bit length " 938 << reader_.GetNumReadBits() << " " << header_.markv_length_in_bits; 939 } 940 941 // Decoding of the module is finished, validation state should have correct 942 // id bound. 943 spirv_[3] = vstate_.getIdBound(); 944 945 *spirv_binary = std::move(spirv_); 946 return SPV_SUCCESS; 947 } 948 949 // TODO(atgoo (at) github.com): The implementation borrows heavily from 950 // Parser::parseOperand. 951 // Consider coupling them together in some way once MARK-V codec is more mature. 952 // For now it's better to keep the code independent for experimentation 953 // purposes. 
954 spv_result_t MarkvDecoder::DecodeOperand( 955 size_t instruction_offset, size_t operand_offset, 956 spv_parsed_instruction_t* inst, const spv_operand_type_t type, 957 spv_operand_pattern_t* expected_operands, 958 bool read_result_id) { 959 const SpvOp opcode = static_cast<SpvOp>(inst->opcode); 960 961 spv_parsed_operand_t parsed_operand; 962 memset(&parsed_operand, 0, sizeof(parsed_operand)); 963 964 assert((operand_offset >> 16) == 0); 965 parsed_operand.offset = static_cast<uint16_t>(operand_offset); 966 parsed_operand.type = type; 967 968 // Set default values, may be updated later. 969 parsed_operand.number_kind = SPV_NUMBER_NONE; 970 parsed_operand.number_bit_width = 0; 971 972 const size_t first_word_index = spirv_.size(); 973 974 switch (type) { 975 case SPV_OPERAND_TYPE_TYPE_ID: { 976 if (!DecodeId(&inst->type_id)) { 977 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 978 << "Failed to read type_id"; 979 } 980 981 if (inst->type_id == 0) 982 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded type_id is 0"; 983 984 spirv_.push_back(inst->type_id); 985 vstate_.setIdBound(std::max(vstate_.getIdBound(), inst->type_id + 1)); 986 break; 987 } 988 989 case SPV_OPERAND_TYPE_RESULT_ID: { 990 if (read_result_id) { 991 if (!DecodeId(&inst->result_id)) 992 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 993 << "Failed to read result_id"; 994 } else { 995 inst->result_id = vstate_.getIdBound(); 996 vstate_.setIdBound(inst->result_id + 1); 997 move_to_front_ids_.push_front(inst->result_id); 998 } 999 1000 spirv_.push_back(inst->result_id); 1001 1002 // Save the result ID to type ID mapping. 1003 // In the grammar, type ID always appears before result ID. 1004 // A regular value maps to its type. Some instructions (e.g. OpLabel) 1005 // have no type Id, and will map to 0. The result Id for a 1006 // type-generating instruction (e.g. OpTypeInt) maps to itself. 
1007 auto insertion_result = id_to_type_id_.emplace( 1008 inst->result_id, 1009 spvOpcodeGeneratesType(opcode) ? inst->result_id : inst->type_id); 1010 if(!insertion_result.second) { 1011 return vstate_.diag(SPV_ERROR_INVALID_ID) 1012 << "Unexpected behavior: id->type_id pair was already registered"; 1013 } 1014 break; 1015 } 1016 1017 case SPV_OPERAND_TYPE_ID: 1018 case SPV_OPERAND_TYPE_OPTIONAL_ID: 1019 case SPV_OPERAND_TYPE_SCOPE_ID: 1020 case SPV_OPERAND_TYPE_MEMORY_SEMANTICS_ID: { 1021 uint32_t id = 0; 1022 if (!DecodeId(&id)) 1023 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Failed to read id"; 1024 1025 if (id == 0) 1026 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Decoded id is 0"; 1027 1028 spirv_.push_back(id); 1029 vstate_.setIdBound(std::max(vstate_.getIdBound(), id + 1)); 1030 1031 if (type == SPV_OPERAND_TYPE_ID || type == SPV_OPERAND_TYPE_OPTIONAL_ID) { 1032 1033 parsed_operand.type = SPV_OPERAND_TYPE_ID; 1034 1035 if (opcode == SpvOpExtInst && parsed_operand.offset == 3) { 1036 // The current word is the extended instruction set id. 1037 // Set the extended instruction set type for the current instruction. 
1038 auto ext_inst_type_iter = import_id_to_ext_inst_type_.find(id); 1039 if (ext_inst_type_iter == import_id_to_ext_inst_type_.end()) { 1040 return vstate_.diag(SPV_ERROR_INVALID_ID) 1041 << "OpExtInst set id " << id 1042 << " does not reference an OpExtInstImport result Id"; 1043 } 1044 inst->ext_inst_type = ext_inst_type_iter->second; 1045 } 1046 } 1047 break; 1048 } 1049 1050 case SPV_OPERAND_TYPE_EXTENSION_INSTRUCTION_NUMBER: { 1051 uint32_t word = 0; 1052 if (!DecodeOperandWord(type, &word)) 1053 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1054 << "Failed to read enum"; 1055 1056 spirv_.push_back(word); 1057 1058 assert(SpvOpExtInst == opcode); 1059 assert(inst->ext_inst_type != SPV_EXT_INST_TYPE_NONE); 1060 spv_ext_inst_desc ext_inst; 1061 if (grammar_.lookupExtInst(inst->ext_inst_type, word, &ext_inst)) 1062 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1063 << "Invalid extended instruction number: " << word; 1064 spvPushOperandTypes(ext_inst->operandTypes, expected_operands); 1065 break; 1066 } 1067 1068 case SPV_OPERAND_TYPE_LITERAL_INTEGER: 1069 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_INTEGER: { 1070 // These are regular single-word literal integer operands. 1071 // Post-parsing validation should check the range of the parsed value. 1072 parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_INTEGER; 1073 // It turns out they are always unsigned integers! 
1074 parsed_operand.number_kind = SPV_NUMBER_UNSIGNED_INT; 1075 parsed_operand.number_bit_width = 32; 1076 1077 uint32_t word = 0; 1078 if (!DecodeOperandWord(type, &word)) 1079 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1080 << "Failed to read literal integer"; 1081 1082 spirv_.push_back(word); 1083 break; 1084 } 1085 1086 case SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER: 1087 case SPV_OPERAND_TYPE_OPTIONAL_TYPED_LITERAL_INTEGER: 1088 parsed_operand.type = SPV_OPERAND_TYPE_TYPED_LITERAL_NUMBER; 1089 if (opcode == SpvOpSwitch) { 1090 // The literal operands have the same type as the value 1091 // referenced by the selector Id. 1092 const uint32_t selector_id = spirv_.at(instruction_offset + 1); 1093 const auto type_id_iter = id_to_type_id_.find(selector_id); 1094 if (type_id_iter == id_to_type_id_.end() || 1095 type_id_iter->second == 0) { 1096 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1097 << "Invalid OpSwitch: selector id " << selector_id 1098 << " has no type"; 1099 } 1100 uint32_t type_id = type_id_iter->second; 1101 1102 if (selector_id == type_id) { 1103 // Recall that by convention, a result ID that is a type definition 1104 // maps to itself. 1105 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1106 << "Invalid OpSwitch: selector id " << selector_id 1107 << " is a type, not a value"; 1108 } 1109 if (auto error = SetNumericTypeInfoForType(&parsed_operand, type_id)) 1110 return error; 1111 if (parsed_operand.number_kind != SPV_NUMBER_UNSIGNED_INT && 1112 parsed_operand.number_kind != SPV_NUMBER_SIGNED_INT) { 1113 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1114 << "Invalid OpSwitch: selector id " << selector_id 1115 << " is not a scalar integer"; 1116 } 1117 } else { 1118 assert(opcode == SpvOpConstant || opcode == SpvOpSpecConstant); 1119 // The literal number type is determined by the type Id for the 1120 // constant. 
1121 assert(inst->type_id); 1122 if (auto error = 1123 SetNumericTypeInfoForType(&parsed_operand, inst->type_id)) 1124 return error; 1125 } 1126 1127 if (auto error = DecodeLiteralNumber(parsed_operand)) 1128 return error; 1129 1130 break; 1131 1132 case SPV_OPERAND_TYPE_LITERAL_STRING: 1133 case SPV_OPERAND_TYPE_OPTIONAL_LITERAL_STRING: { 1134 parsed_operand.type = SPV_OPERAND_TYPE_LITERAL_STRING; 1135 std::vector<char> str; 1136 // The loop is expected to terminate once we encounter '\0' or exhaust 1137 // the bit stream. 1138 while (true) { 1139 char ch = 0; 1140 if (!reader_.ReadUnencoded(&ch)) 1141 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1142 << "Failed to read literal string"; 1143 1144 str.push_back(ch); 1145 1146 if (ch == '\0') 1147 break; 1148 } 1149 1150 while (str.size() % 4 != 0) 1151 str.push_back('\0'); 1152 1153 spirv_.resize(spirv_.size() + str.size() / 4); 1154 std::memcpy(&spirv_[first_word_index], str.data(), str.size()); 1155 1156 if (SpvOpExtInstImport == opcode) { 1157 // Record the extended instruction type for the ID for this import. 1158 // There is only one string literal argument to OpExtInstImport, 1159 // so it's sufficient to guard this just on the opcode. 1160 const spv_ext_inst_type_t ext_inst_type = 1161 spvExtInstImportTypeGet(str.data()); 1162 if (SPV_EXT_INST_TYPE_NONE == ext_inst_type) { 1163 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1164 << "Invalid extended instruction import '" << str.data() << "'"; 1165 } 1166 // We must have parsed a valid result ID. It's a condition 1167 // of the grammar, and we only accept non-zero result Ids. 
1168 assert(inst->result_id); 1169 const bool inserted = import_id_to_ext_inst_type_.emplace( 1170 inst->result_id, ext_inst_type).second; 1171 (void)inserted; 1172 assert(inserted); 1173 } 1174 break; 1175 } 1176 1177 case SPV_OPERAND_TYPE_CAPABILITY: 1178 case SPV_OPERAND_TYPE_SOURCE_LANGUAGE: 1179 case SPV_OPERAND_TYPE_EXECUTION_MODEL: 1180 case SPV_OPERAND_TYPE_ADDRESSING_MODEL: 1181 case SPV_OPERAND_TYPE_MEMORY_MODEL: 1182 case SPV_OPERAND_TYPE_EXECUTION_MODE: 1183 case SPV_OPERAND_TYPE_STORAGE_CLASS: 1184 case SPV_OPERAND_TYPE_DIMENSIONALITY: 1185 case SPV_OPERAND_TYPE_SAMPLER_ADDRESSING_MODE: 1186 case SPV_OPERAND_TYPE_SAMPLER_FILTER_MODE: 1187 case SPV_OPERAND_TYPE_SAMPLER_IMAGE_FORMAT: 1188 case SPV_OPERAND_TYPE_FP_ROUNDING_MODE: 1189 case SPV_OPERAND_TYPE_LINKAGE_TYPE: 1190 case SPV_OPERAND_TYPE_ACCESS_QUALIFIER: 1191 case SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER: 1192 case SPV_OPERAND_TYPE_FUNCTION_PARAMETER_ATTRIBUTE: 1193 case SPV_OPERAND_TYPE_DECORATION: 1194 case SPV_OPERAND_TYPE_BUILT_IN: 1195 case SPV_OPERAND_TYPE_GROUP_OPERATION: 1196 case SPV_OPERAND_TYPE_KERNEL_ENQ_FLAGS: 1197 case SPV_OPERAND_TYPE_KERNEL_PROFILING_INFO: { 1198 // A single word that is a plain enum value. 1199 uint32_t word = 0; 1200 if (!DecodeOperandWord(type, &word)) 1201 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1202 << "Failed to read enum"; 1203 1204 spirv_.push_back(word); 1205 1206 // Map an optional operand type to its corresponding concrete type. 1207 if (type == SPV_OPERAND_TYPE_OPTIONAL_ACCESS_QUALIFIER) 1208 parsed_operand.type = SPV_OPERAND_TYPE_ACCESS_QUALIFIER; 1209 1210 spv_operand_desc entry; 1211 if (grammar_.lookupOperand(type, word, &entry)) { 1212 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1213 << "Invalid " 1214 << spvOperandTypeStr(parsed_operand.type) 1215 << " operand: " << word; 1216 } 1217 1218 // Prepare to accept operands to this operand, if needed. 
1219 spvPushOperandTypes(entry->operandTypes, expected_operands); 1220 break; 1221 } 1222 1223 case SPV_OPERAND_TYPE_FP_FAST_MATH_MODE: 1224 case SPV_OPERAND_TYPE_FUNCTION_CONTROL: 1225 case SPV_OPERAND_TYPE_LOOP_CONTROL: 1226 case SPV_OPERAND_TYPE_IMAGE: 1227 case SPV_OPERAND_TYPE_OPTIONAL_IMAGE: 1228 case SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS: 1229 case SPV_OPERAND_TYPE_SELECTION_CONTROL: { 1230 // This operand is a mask. 1231 uint32_t word = 0; 1232 if (!DecodeOperandWord(type, &word)) 1233 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1234 << "Failed to read " << spvOperandTypeStr(type) 1235 << " for " << spvOpcodeString(SpvOp(inst->opcode)); 1236 1237 spirv_.push_back(word); 1238 1239 // Map an optional operand type to its corresponding concrete type. 1240 if (type == SPV_OPERAND_TYPE_OPTIONAL_IMAGE) 1241 parsed_operand.type = SPV_OPERAND_TYPE_IMAGE; 1242 else if (type == SPV_OPERAND_TYPE_OPTIONAL_MEMORY_ACCESS) 1243 parsed_operand.type = SPV_OPERAND_TYPE_MEMORY_ACCESS; 1244 1245 // Check validity of set mask bits. Also prepare for operands for those 1246 // masks if they have any. To get operand order correct, scan from 1247 // MSB to LSB since we can only prepend operands to a pattern. 1248 // The only case in the grammar where you have more than one mask bit 1249 // having an operand is for image operands. See SPIR-V 3.14 Image 1250 // Operands. 1251 uint32_t remaining_word = word; 1252 for (uint32_t mask = (1u << 31); remaining_word; mask >>= 1) { 1253 if (remaining_word & mask) { 1254 spv_operand_desc entry; 1255 if (grammar_.lookupOperand(type, mask, &entry)) { 1256 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1257 << "Invalid " << spvOperandTypeStr(parsed_operand.type) 1258 << " operand: " << word << " has invalid mask component " 1259 << mask; 1260 } 1261 remaining_word ^= mask; 1262 spvPushOperandTypes(entry->operandTypes, expected_operands); 1263 } 1264 } 1265 if (word == 0) { 1266 // An all-zeroes mask *might* also be valid. 
1267 spv_operand_desc entry; 1268 if (SPV_SUCCESS == grammar_.lookupOperand(type, 0, &entry)) { 1269 // Prepare for its operands, if any. 1270 spvPushOperandTypes(entry->operandTypes, expected_operands); 1271 } 1272 } 1273 break; 1274 } 1275 default: 1276 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1277 << "Internal error: Unhandled operand type: " << type; 1278 } 1279 1280 parsed_operand.num_words = uint16_t(spirv_.size() - first_word_index); 1281 1282 assert(int(SPV_OPERAND_TYPE_FIRST_CONCRETE_TYPE) <= int(parsed_operand.type)); 1283 assert(int(SPV_OPERAND_TYPE_LAST_CONCRETE_TYPE) >= int(parsed_operand.type)); 1284 1285 parsed_operands_.push_back(parsed_operand); 1286 1287 return SPV_SUCCESS; 1288 } 1289 1290 spv_result_t MarkvDecoder::DecodeInstruction(spv_parsed_instruction_t* inst) { 1291 parsed_operands_.clear(); 1292 const size_t instruction_offset = spirv_.size(); 1293 1294 bool read_result_id = false; 1295 1296 while (true) { 1297 uint32_t word = 0; 1298 if (!reader_.ReadVariableWidthU32(&word, 1299 model_->opcode_chunk_length())) { 1300 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1301 << "Failed to read opcode of instruction"; 1302 } 1303 1304 if (word >= kMarkvFirstOpcode) { 1305 if (word == kMarkvOpNextInstructionEncodesResultId) { 1306 read_result_id = true; 1307 } else { 1308 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1309 << "Encountered unknown MARK-V opcode"; 1310 } 1311 } else { 1312 inst->opcode = static_cast<uint16_t>(word); 1313 break; 1314 } 1315 } 1316 1317 const SpvOp opcode = static_cast<SpvOp>(inst->opcode); 1318 1319 // Opcode/num_words placeholder, the word will be filled in later. 
1320 spirv_.push_back(0); 1321 1322 spv_opcode_desc opcode_desc; 1323 if (grammar_.lookupOpcode(opcode, &opcode_desc) 1324 != SPV_SUCCESS) { 1325 return vstate_.diag(SPV_ERROR_INVALID_BINARY) << "Invalid opcode"; 1326 } 1327 1328 spv_operand_pattern_t expected_operands; 1329 expected_operands.reserve(opcode_desc->numTypes); 1330 for (auto i = 0; i < opcode_desc->numTypes; i++) 1331 expected_operands.push_back(opcode_desc->operandTypes[opcode_desc->numTypes - i - 1]); 1332 1333 if (!OpcodeHasFixedNumberOfOperands(opcode)) { 1334 if (!reader_.ReadVariableWidthU16(&inst->num_operands, 1335 model_->num_operands_chunk_length())) 1336 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1337 << "Failed to read num_operands of instruction"; 1338 } else { 1339 inst->num_operands = static_cast<uint16_t>(expected_operands.size()); 1340 } 1341 1342 for (size_t operand_index = 0; 1343 operand_index < static_cast<size_t>(inst->num_operands); 1344 ++operand_index) { 1345 assert(!expected_operands.empty()); 1346 const spv_operand_type_t type = 1347 spvTakeFirstMatchableOperand(&expected_operands); 1348 1349 const size_t operand_offset = spirv_.size() - instruction_offset; 1350 1351 const spv_result_t decode_result = 1352 DecodeOperand(instruction_offset, operand_offset, inst, type, 1353 &expected_operands, read_result_id); 1354 1355 if (decode_result != SPV_SUCCESS) 1356 return decode_result; 1357 } 1358 1359 assert(inst->num_operands == parsed_operands_.size()); 1360 1361 // Only valid while spirv_ and parsed_operands_ remain unchanged. 1362 inst->words = &spirv_[instruction_offset]; 1363 inst->operands = parsed_operands_.empty() ? 
nullptr : parsed_operands_.data(); 1364 inst->num_words = static_cast<uint16_t>(spirv_.size() - instruction_offset); 1365 spirv_[instruction_offset] = 1366 spvOpcodeMake(inst->num_words, SpvOp(inst->opcode)); 1367 1368 assert(inst->num_words == std::accumulate( 1369 parsed_operands_.begin(), parsed_operands_.end(), 1, 1370 [](int num_words, const spv_parsed_operand_t& operand) { 1371 return num_words += operand.num_words; 1372 }) && "num_words in instruction doesn't correspond to the sum of num_words" 1373 "in the operands"); 1374 1375 RecordNumberType(*inst); 1376 1377 if (!ReadToByteBreakIfAgreed()) 1378 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1379 << "Failed to read to byte break"; 1380 1381 return SPV_SUCCESS; 1382 } 1383 1384 spv_result_t MarkvDecoder::SetNumericTypeInfoForType( 1385 spv_parsed_operand_t* parsed_operand, uint32_t type_id) { 1386 assert(type_id != 0); 1387 auto type_info_iter = type_id_to_number_type_info_.find(type_id); 1388 if (type_info_iter == type_id_to_number_type_info_.end()) { 1389 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1390 << "Type Id " << type_id << " is not a type"; 1391 } 1392 1393 const NumberType& info = type_info_iter->second; 1394 if (info.type == SPV_NUMBER_NONE) { 1395 // This is a valid type, but for something other than a scalar number. 1396 return vstate_.diag(SPV_ERROR_INVALID_BINARY) 1397 << "Type Id " << type_id << " is not a scalar numeric type"; 1398 } 1399 1400 parsed_operand->number_kind = info.type; 1401 parsed_operand->number_bit_width = info.bit_width; 1402 // Round up the word count. 
1403 parsed_operand->num_words = static_cast<uint16_t>((info.bit_width + 31) / 32); 1404 return SPV_SUCCESS; 1405 } 1406 1407 void MarkvDecoder::RecordNumberType(const spv_parsed_instruction_t& inst) { 1408 const SpvOp opcode = static_cast<SpvOp>(inst.opcode); 1409 if (spvOpcodeGeneratesType(opcode)) { 1410 NumberType info = {SPV_NUMBER_NONE, 0}; 1411 if (SpvOpTypeInt == opcode) { 1412 info.bit_width = inst.words[inst.operands[1].offset]; 1413 info.type = inst.words[inst.operands[2].offset] ? 1414 SPV_NUMBER_SIGNED_INT : SPV_NUMBER_UNSIGNED_INT; 1415 } else if (SpvOpTypeFloat == opcode) { 1416 info.bit_width = inst.words[inst.operands[1].offset]; 1417 info.type = SPV_NUMBER_FLOATING; 1418 } 1419 // The *result* Id of a type generating instruction is the type Id. 1420 type_id_to_number_type_info_[inst.result_id] = info; 1421 } 1422 } 1423 1424 spv_result_t EncodeHeader( 1425 void* user_data, spv_endianness_t endian, uint32_t magic, 1426 uint32_t version, uint32_t generator, uint32_t id_bound, 1427 uint32_t schema) { 1428 MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data); 1429 return encoder->EncodeHeader( 1430 endian, magic, version, generator, id_bound, schema); 1431 } 1432 1433 spv_result_t EncodeInstruction( 1434 void* user_data, const spv_parsed_instruction_t* inst) { 1435 MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data); 1436 return encoder->EncodeInstruction(*inst); 1437 } 1438 1439 } // namespace 1440 1441 spv_result_t spvSpirvToMarkv(spv_const_context context, 1442 const uint32_t* spirv_words, 1443 const size_t spirv_num_words, 1444 spv_const_markv_encoder_options options, 1445 spv_markv_binary* markv_binary, 1446 spv_text* comments, spv_diagnostic* diagnostic) { 1447 spv_context_t hijack_context = *context; 1448 if (diagnostic) { 1449 *diagnostic = nullptr; 1450 libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); 1451 } 1452 1453 spv_const_binary_t spirv_binary = {spirv_words, spirv_num_words}; 1454 1455 
spv_endianness_t endian; 1456 spv_position_t position = {}; 1457 if (spvBinaryEndianness(&spirv_binary, &endian)) { 1458 return libspirv::DiagnosticStream(position, hijack_context.consumer, 1459 SPV_ERROR_INVALID_BINARY) 1460 << "Invalid SPIR-V magic number."; 1461 } 1462 1463 spv_header_t header; 1464 if (spvBinaryHeaderGet(&spirv_binary, endian, &header)) { 1465 return libspirv::DiagnosticStream(position, hijack_context.consumer, 1466 SPV_ERROR_INVALID_BINARY) 1467 << "Invalid SPIR-V header."; 1468 } 1469 1470 MarkvEncoder encoder(&hijack_context, options); 1471 1472 if (comments) { 1473 encoder.CreateCommentsLogger(); 1474 1475 spv_text text = nullptr; 1476 if (spvBinaryToText(&hijack_context, spirv_words, spirv_num_words, 1477 SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text, nullptr) 1478 != SPV_SUCCESS) { 1479 return libspirv::DiagnosticStream(position, hijack_context.consumer, 1480 SPV_ERROR_INVALID_BINARY) 1481 << "Failed to disassemble SPIR-V binary."; 1482 } 1483 assert(text); 1484 encoder.SetDisassembly(std::string(text->str, text->length)); 1485 spvTextDestroy(text); 1486 } 1487 1488 if (spvBinaryParse( 1489 &hijack_context, &encoder, spirv_words, spirv_num_words, EncodeHeader, 1490 EncodeInstruction, diagnostic) != SPV_SUCCESS) { 1491 return libspirv::DiagnosticStream(position, hijack_context.consumer, 1492 SPV_ERROR_INVALID_BINARY) 1493 << "Unable to encode to MARK-V."; 1494 } 1495 1496 if (comments) 1497 *comments = CreateSpvText(encoder.GetComments()); 1498 1499 *markv_binary = encoder.GetMarkvBinary(); 1500 return SPV_SUCCESS; 1501 } 1502 1503 spv_result_t spvMarkvToSpirv(spv_const_context context, 1504 const uint8_t* markv_data, 1505 size_t markv_size_bytes, 1506 spv_const_markv_decoder_options options, 1507 spv_binary* spirv_binary, 1508 spv_text* /* comments */, spv_diagnostic* diagnostic) { 1509 spv_position_t position = {}; 1510 spv_context_t hijack_context = *context; 1511 if (diagnostic) { 1512 *diagnostic = nullptr; 1513 
libspirv::UseDiagnosticAsMessageConsumer(&hijack_context, diagnostic); 1514 } 1515 1516 MarkvDecoder decoder(&hijack_context, markv_data, markv_size_bytes, options); 1517 1518 std::vector<uint32_t> words; 1519 1520 if (decoder.DecodeModule(&words) != SPV_SUCCESS) { 1521 return libspirv::DiagnosticStream(position, hijack_context.consumer, 1522 SPV_ERROR_INVALID_BINARY) 1523 << "Unable to decode MARK-V."; 1524 } 1525 1526 assert(!words.empty()); 1527 1528 *spirv_binary = new spv_binary_t(); 1529 (*spirv_binary)->code = new uint32_t[words.size()]; 1530 (*spirv_binary)->wordCount = words.size(); 1531 std::memcpy((*spirv_binary)->code, words.data(), 4 * words.size()); 1532 1533 return SPV_SUCCESS; 1534 } 1535 1536 void spvMarkvBinaryDestroy(spv_markv_binary binary) { 1537 if (!binary) return; 1538 delete[] binary->data; 1539 delete binary; 1540 } 1541 1542 spv_markv_encoder_options spvMarkvEncoderOptionsCreate() { 1543 return new spv_markv_encoder_options_t; 1544 } 1545 1546 void spvMarkvEncoderOptionsDestroy(spv_markv_encoder_options options) { 1547 delete options; 1548 } 1549 1550 spv_markv_decoder_options spvMarkvDecoderOptionsCreate() { 1551 return new spv_markv_decoder_options_t; 1552 } 1553 1554 void spvMarkvDecoderOptionsDestroy(spv_markv_decoder_options options) { 1555 delete options; 1556 } 1557