Home | History | Annotate | Download | only in comp
      1 // Copyright (c) 2018 Google LLC
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include "source/comp/bit_stream.h"
     16 #include "source/comp/markv.h"
     17 #include "source/comp/markv_codec.h"
     18 #include "source/comp/markv_logger.h"
     19 #include "source/util/make_unique.h"
     20 
     21 #ifndef SOURCE_COMP_MARKV_DECODER_H_
     22 #define SOURCE_COMP_MARKV_DECODER_H_
     23 
     24 namespace spvtools {
     25 namespace comp {
     26 
     27 class MarkvLogger;
     28 
     29 // Decodes MARK-V buffers written by MarkvEncoder.
     30 class MarkvDecoder : public MarkvCodec {
     31  public:
     32   // |model| is owned by the caller, must be not null and valid during the
     33   // lifetime of MarkvEncoder.
     34   MarkvDecoder(spv_const_context context, const std::vector<uint8_t>& markv,
     35                const MarkvCodecOptions& options, const MarkvModel* model)
     36       : MarkvCodec(context, GetValidatorOptions(options), model),
     37         options_(options),
     38         reader_(markv) {
     39     SetIdBound(1);
     40     parsed_operands_.reserve(25);
     41     inst_words_.reserve(25);
     42   }
     43   ~MarkvDecoder() = default;
     44 
     45   // Creates an internal logger which writes comments on the decoding process.
     46   void CreateLogger(MarkvLogConsumer log_consumer,
     47                     MarkvDebugConsumer debug_consumer) {
     48     logger_ = MakeUnique<MarkvLogger>(log_consumer, debug_consumer);
     49   }
     50 
     51   // Decodes SPIR-V from MARK-V and stores the words in |spirv_binary|.
     52   // Can be called only once. Fails if data of wrong format or ends prematurely,
     53   // of if validation fails.
     54   spv_result_t DecodeModule(std::vector<uint32_t>* spirv_binary);
     55 
     56   // Creates and returns validator options. Returned value owned by the caller.
     57   static spv_validator_options GetValidatorOptions(
     58       const MarkvCodecOptions& options) {
     59     return options.validate_spirv_binary ? spvValidatorOptionsCreate()
     60                                          : nullptr;
     61   }
     62 
     63  private:
     64   // Describes the format of a typed literal number.
     65   struct NumberType {
     66     spv_number_kind_t type;
     67     uint32_t bit_width;
     68   };
     69 
     70   // Reads a single bit from reader_. The read bit is stored in |bit|.
     71   // Returns false iff reader_ fails.
     72   bool ReadBit(bool* bit) {
     73     uint64_t bits = 0;
     74     const bool result = reader_.ReadBits(&bits, 1);
     75     if (result) *bit = bits ? true : false;
     76     return result;
     77   };
     78 
     79   // Returns ReadBit bound to the class object.
     80   std::function<bool(bool*)> GetReadBitCallback() {
     81     return std::bind(&MarkvDecoder::ReadBit, this, std::placeholders::_1);
     82   }
     83 
     84   // Reads a single non-id word from bit stream. operand_.type determines if
     85   // the word needs to be decoded and how.
     86   spv_result_t DecodeNonIdWord(uint32_t* word);
     87 
     88   // Reads and decodes both opcode and num_operands as a single code.
     89   // Returns SPV_UNSUPPORTED iff no suitable codec was found.
     90   spv_result_t DecodeOpcodeAndNumberOfOperands(uint32_t* opcode,
     91                                                uint32_t* num_operands);
     92 
     93   // Reads mtf rank from bit stream. |mtf| is used to determine the codec
     94   // scheme. |fallback_method| is used if no codec defined for |mtf|.
     95   spv_result_t DecodeMtfRankHuffman(uint64_t mtf, uint32_t fallback_method,
     96                                     uint32_t* rank);
     97 
     98   // Reads id using coding based on mtf associated with the id descriptor.
     99   // Returns SPV_UNSUPPORTED iff fallback method needs to be used.
    100   spv_result_t DecodeIdWithDescriptor(uint32_t* id);
    101 
    102   // Reads id using coding based on the given |mtf|, which is expected to
    103   // contain the needed |id|.
    104   spv_result_t DecodeExistingId(uint64_t mtf, uint32_t* id);
    105 
    106   // Reads type id of the current instruction if can't be inferred.
    107   spv_result_t DecodeTypeId();
    108 
    109   // Reads result id of the current instruction if can't be inferred.
    110   spv_result_t DecodeResultId();
    111 
    112   // Reads id which is neither type nor result id.
    113   spv_result_t DecodeRefId(uint32_t* id);
    114 
    115   // Reads and discards bits until the beginning of the next byte if the
    116   // number of bits until the next byte is less than |byte_break_if_less_than|.
    117   bool ReadToByteBreak(size_t byte_break_if_less_than);
    118 
    119   // Returns instruction words decoded up to this point.
    120   const uint32_t* GetInstWords() const override { return inst_words_.data(); }
    121 
    122   // Reads a literal number as it is described in |operand| from the bit stream,
    123   // decodes and writes it to spirv_.
    124   spv_result_t DecodeLiteralNumber(const spv_parsed_operand_t& operand);
    125 
    126   // Reads instruction from bit stream, decodes and validates it.
    127   // Decoded instruction is valid until the next call of DecodeInstruction().
    128   spv_result_t DecodeInstruction();
    129 
    130   // Read operand from the stream decodes and validates it.
    131   spv_result_t DecodeOperand(size_t operand_offset,
    132                              const spv_operand_type_t type,
    133                              spv_operand_pattern_t* expected_operands);
    134 
    135   // Records the numeric type for an operand according to the type information
    136   // associated with the given non-zero type Id.  This can fail if the type Id
    137   // is not a type Id, or if the type Id does not reference a scalar numeric
    138   // type.  On success, return SPV_SUCCESS and populates the num_words,
    139   // number_kind, and number_bit_width fields of parsed_operand.
    140   spv_result_t SetNumericTypeInfoForType(spv_parsed_operand_t* parsed_operand,
    141                                          uint32_t type_id);
    142 
    143   // Records the number type for the current instruction, if it generates a
    144   // type. For types that aren't scalar numbers, record something with number
    145   // kind SPV_NUMBER_NONE.
    146   void RecordNumberType();
    147 
    148   MarkvCodecOptions options_;
    149 
    150   // Temporary sink where decoded SPIR-V words are written. Once it contains the
    151   // entire module, the container is moved and returned.
    152   std::vector<uint32_t> spirv_;
    153 
    154   // Bit stream containing encoded data.
    155   BitReaderWord64 reader_;
    156 
    157   // Temporary storage for operands of the currently parsed instruction.
    158   // Valid until next DecodeInstruction call.
    159   std::vector<spv_parsed_operand_t> parsed_operands_;
    160 
    161   // Temporary storage for current instruction words.
    162   // Valid until next DecodeInstruction call.
    163   std::vector<uint32_t> inst_words_;
    164 
    165   // Maps a type ID to its number type description.
    166   std::unordered_map<uint32_t, NumberType> type_id_to_number_type_info_;
    167 
    168   // Maps an ExtInstImport id to the extended instruction type.
    169   std::unordered_map<uint32_t, spv_ext_inst_type_t> import_id_to_ext_inst_type_;
    170 };
    171 
    172 }  // namespace comp
    173 }  // namespace spvtools
    174 
    175 #endif  // SOURCE_COMP_MARKV_DECODER_H_
    176