Home | History | Annotate | Download | only in comp
      1 // Copyright (c) 2018 Google LLC
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
#include "source/comp/markv.h"

#include <cassert>
#include <memory>

#include "source/comp/markv_decoder.h"
#include "source/comp/markv_encoder.h"
     19 
     20 namespace spvtools {
     21 namespace comp {
     22 namespace {
     23 
     24 spv_result_t EncodeHeader(void* user_data, spv_endianness_t endian,
     25                           uint32_t magic, uint32_t version, uint32_t generator,
     26                           uint32_t id_bound, uint32_t schema) {
     27   MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
     28   return encoder->EncodeHeader(endian, magic, version, generator, id_bound,
     29                                schema);
     30 }
     31 
     32 spv_result_t EncodeInstruction(void* user_data,
     33                                const spv_parsed_instruction_t* inst) {
     34   MarkvEncoder* encoder = reinterpret_cast<MarkvEncoder*>(user_data);
     35   return encoder->EncodeInstruction(*inst);
     36 }
     37 
     38 }  // namespace
     39 
     40 spv_result_t SpirvToMarkv(
     41     spv_const_context context, const std::vector<uint32_t>& spirv,
     42     const MarkvCodecOptions& options, const MarkvModel& markv_model,
     43     MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
     44     MarkvDebugConsumer debug_consumer, std::vector<uint8_t>* markv) {
     45   spv_context_t hijack_context = *context;
     46   SetContextMessageConsumer(&hijack_context, message_consumer);
     47 
     48   spv_validator_options validator_options =
     49       MarkvDecoder::GetValidatorOptions(options);
     50   if (validator_options) {
     51     spv_const_binary_t spirv_binary = {spirv.data(), spirv.size()};
     52     const spv_result_t result = spvValidateWithOptions(
     53         &hijack_context, validator_options, &spirv_binary, nullptr);
     54     if (result != SPV_SUCCESS) return result;
     55   }
     56 
     57   MarkvEncoder encoder(&hijack_context, options, &markv_model);
     58 
     59   spv_position_t position = {};
     60   if (log_consumer || debug_consumer) {
     61     encoder.CreateLogger(log_consumer, debug_consumer);
     62 
     63     spv_text text = nullptr;
     64     if (spvBinaryToText(&hijack_context, spirv.data(), spirv.size(),
     65                         SPV_BINARY_TO_TEXT_OPTION_NO_HEADER, &text,
     66                         nullptr) != SPV_SUCCESS) {
     67       return DiagnosticStream(position, hijack_context.consumer, "",
     68                               SPV_ERROR_INVALID_BINARY)
     69              << "Failed to disassemble SPIR-V binary.";
     70     }
     71     assert(text);
     72     encoder.SetDisassembly(std::string(text->str, text->length));
     73     spvTextDestroy(text);
     74   }
     75 
     76   if (spvBinaryParse(&hijack_context, &encoder, spirv.data(), spirv.size(),
     77                      EncodeHeader, EncodeInstruction, nullptr) != SPV_SUCCESS) {
     78     return DiagnosticStream(position, hijack_context.consumer, "",
     79                             SPV_ERROR_INVALID_BINARY)
     80            << "Unable to encode to MARK-V.";
     81   }
     82 
     83   *markv = encoder.GetMarkvBinary();
     84   return SPV_SUCCESS;
     85 }
     86 
     87 spv_result_t MarkvToSpirv(
     88     spv_const_context context, const std::vector<uint8_t>& markv,
     89     const MarkvCodecOptions& options, const MarkvModel& markv_model,
     90     MessageConsumer message_consumer, MarkvLogConsumer log_consumer,
     91     MarkvDebugConsumer debug_consumer, std::vector<uint32_t>* spirv) {
     92   spv_position_t position = {};
     93   spv_context_t hijack_context = *context;
     94   SetContextMessageConsumer(&hijack_context, message_consumer);
     95 
     96   MarkvDecoder decoder(&hijack_context, markv, options, &markv_model);
     97 
     98   if (log_consumer || debug_consumer)
     99     decoder.CreateLogger(log_consumer, debug_consumer);
    100 
    101   if (decoder.DecodeModule(spirv) != SPV_SUCCESS) {
    102     return DiagnosticStream(position, hijack_context.consumer, "",
    103                             SPV_ERROR_INVALID_BINARY)
    104            << "Unable to decode MARK-V.";
    105   }
    106 
    107   assert(!spirv->empty());
    108   return SPV_SUCCESS;
    109 }
    110 
    111 }  // namespace comp
    112 }  // namespace spvtools
    113