Home | History | Annotate | Download | only in comp
      1 // Copyright (c) 2017 Google Inc.
      2 //
      3 // Licensed under the Apache License, Version 2.0 (the "License");
      4 // you may not use this file except in compliance with the License.
      5 // You may obtain a copy of the License at
      6 //
      7 //     http://www.apache.org/licenses/LICENSE-2.0
      8 //
      9 // Unless required by applicable law or agreed to in writing, software
     10 // distributed under the License is distributed on an "AS IS" BASIS,
     11 // WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 // See the License for the specific language governing permissions and
     13 // limitations under the License.
     14 
     15 #include <algorithm>
     16 #include <cassert>
     17 #include <cstdio>
     18 #include <cstring>
     19 #include <functional>
     20 #include <iostream>
     21 #include <memory>
     22 #include <string>
     23 #include <utility>
     24 #include <vector>
     25 
     26 #include "source/comp/markv.h"
     27 #include "source/spirv_target_env.h"
     28 #include "source/table.h"
     29 #include "spirv-tools/optimizer.hpp"
     30 #include "tools/comp/markv_model_factory.h"
     31 #include "tools/io.h"
     32 
     33 namespace {
     34 
     35 const auto kSpvEnv = SPV_ENV_UNIVERSAL_1_2;
     36 
     37 enum Task {
     38   kNoTask = 0,
     39   kEncode,
     40   kDecode,
     41   kTest,
     42 };
     43 
     44 struct ScopedContext {
     45   ScopedContext(spv_target_env env) : context(spvContextCreate(env)) {}
     46   ~ScopedContext() { spvContextDestroy(context); }
     47   spv_context context;
     48 };
     49 
     50 void print_usage(char* argv0) {
     51   printf(
     52       R"(%s - Encodes or decodes a SPIR-V binary to or from a MARK-V binary.
     53 
     54 USAGE: %s [e|d|t] [options] [<filename>]
     55 
     56 The input binary is read from <filename>. If no file is specified,
     57 or if the filename is "-", then the binary is read from standard input.
     58 
     59 If no output is specified then the output is printed to stdout in a human
     60 readable format.
     61 
     62 WIP: MARK-V codec is in early stages of development. At the moment it only
     63 can encode and decode some SPIR-V files and only if exacly the same build of
     64 software is used (is doesn't write or handle version numbers yet).
     65 
     66 Tasks:
     67   e               Encode SPIR-V to MARK-V.
     68   d               Decode MARK-V to SPIR-V.
     69   t               Test the codec by first encoding the given SPIR-V file to
     70                   MARK-V, then decoding it back to SPIR-V and comparing results.
     71 
     72 Options:
     73   -h, --help      Print this help.
     74   --comments      Write codec comments to stderr.
     75   --version       Display MARK-V codec version.
     76   --validate      Validate SPIR-V while encoding or decoding.
     77   --model=<model-name>
     78                   Compression model, possible values:
     79                   shader_lite - fast, poor compression ratio
     80                   shader_mid - balanced
     81                   shader_max - best compression ratio
     82                   Default: shader_lite
     83 
     84   -o <filename>   Set the output filename.
     85                   Output goes to standard output if this option is
     86                   not specified, or if the filename is "-".
     87                   Not needed for 't' task (testing).
     88 )",
     89       argv0, argv0);
     90 }
     91 
     92 void DiagnosticsMessageHandler(spv_message_level_t level, const char*,
     93                                const spv_position_t& position,
     94                                const char* message) {
     95   switch (level) {
     96     case SPV_MSG_FATAL:
     97     case SPV_MSG_INTERNAL_ERROR:
     98     case SPV_MSG_ERROR:
     99       std::cerr << "error: " << position.index << ": " << message << std::endl;
    100       break;
    101     case SPV_MSG_WARNING:
    102       std::cerr << "warning: " << position.index << ": " << message
    103                 << std::endl;
    104       break;
    105     case SPV_MSG_INFO:
    106       std::cerr << "info: " << position.index << ": " << message << std::endl;
    107       break;
    108     default:
    109       break;
    110   }
    111 }
    112 
    113 }  // namespace
    114 
    115 int main(int argc, char** argv) {
    116   const char* input_filename = nullptr;
    117   const char* output_filename = nullptr;
    118 
    119   Task task = kNoTask;
    120 
    121   if (argc < 3) {
    122     print_usage(argv[0]);
    123     return 0;
    124   }
    125 
    126   const char* task_char = argv[1];
    127   if (0 == strcmp("e", task_char)) {
    128     task = kEncode;
    129   } else if (0 == strcmp("d", task_char)) {
    130     task = kDecode;
    131   } else if (0 == strcmp("t", task_char)) {
    132     task = kTest;
    133   }
    134 
    135   if (task == kNoTask) {
    136     print_usage(argv[0]);
    137     return 1;
    138   }
    139 
    140   bool want_comments = false;
    141   bool validate_spirv_binary = false;
    142 
    143   spvtools::comp::MarkvModelType model_type =
    144       spvtools::comp::kMarkvModelUnknown;
    145 
    146   for (int argi = 2; argi < argc; ++argi) {
    147     if ('-' == argv[argi][0]) {
    148       switch (argv[argi][1]) {
    149         case 'h':
    150           print_usage(argv[0]);
    151           return 0;
    152         case 'o': {
    153           if (!output_filename && argi + 1 < argc &&
    154               (task == kEncode || task == kDecode)) {
    155             output_filename = argv[++argi];
    156           } else {
    157             print_usage(argv[0]);
    158             return 1;
    159           }
    160         } break;
    161         case '-': {
    162           if (0 == strcmp(argv[argi], "--help")) {
    163             print_usage(argv[0]);
    164             return 0;
    165           } else if (0 == strcmp(argv[argi], "--comments")) {
    166             want_comments = true;
    167           } else if (0 == strcmp(argv[argi], "--version")) {
    168             fprintf(stderr, "error: Not implemented\n");
    169             return 1;
    170           } else if (0 == strcmp(argv[argi], "--validate")) {
    171             validate_spirv_binary = true;
    172           } else if (0 == strcmp(argv[argi], "--model=shader_lite")) {
    173             if (model_type != spvtools::comp::kMarkvModelUnknown)
    174               fprintf(stderr, "error: More than one model specified\n");
    175             model_type = spvtools::comp::kMarkvModelShaderLite;
    176           } else if (0 == strcmp(argv[argi], "--model=shader_mid")) {
    177             if (model_type != spvtools::comp::kMarkvModelUnknown)
    178               fprintf(stderr, "error: More than one model specified\n");
    179             model_type = spvtools::comp::kMarkvModelShaderMid;
    180           } else if (0 == strcmp(argv[argi], "--model=shader_max")) {
    181             if (model_type != spvtools::comp::kMarkvModelUnknown)
    182               fprintf(stderr, "error: More than one model specified\n");
    183             model_type = spvtools::comp::kMarkvModelShaderMax;
    184           } else {
    185             print_usage(argv[0]);
    186             return 1;
    187           }
    188         } break;
    189         case '\0': {
    190           // Setting a filename of "-" to indicate stdin.
    191           if (!input_filename) {
    192             input_filename = argv[argi];
    193           } else {
    194             fprintf(stderr, "error: More than one input file specified\n");
    195             return 1;
    196           }
    197         } break;
    198         default:
    199           print_usage(argv[0]);
    200           return 1;
    201       }
    202     } else {
    203       if (!input_filename) {
    204         input_filename = argv[argi];
    205       } else {
    206         fprintf(stderr, "error: More than one input file specified\n");
    207         return 1;
    208       }
    209     }
    210   }
    211 
    212   if (model_type == spvtools::comp::kMarkvModelUnknown)
    213     model_type = spvtools::comp::kMarkvModelShaderLite;
    214 
    215   const auto no_comments = spvtools::comp::MarkvLogConsumer();
    216   const auto output_to_stderr = [](const std::string& str) {
    217     std::cerr << str;
    218   };
    219 
    220   ScopedContext ctx(kSpvEnv);
    221 
    222   std::unique_ptr<spvtools::comp::MarkvModel> model =
    223       spvtools::comp::CreateMarkvModel(model_type);
    224 
    225   std::vector<uint32_t> spirv;
    226   std::vector<uint8_t> markv;
    227 
    228   spvtools::comp::MarkvCodecOptions options;
    229   options.validate_spirv_binary = validate_spirv_binary;
    230 
    231   if (task == kEncode) {
    232     if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
    233     assert(!spirv.empty());
    234 
    235     if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
    236                            ctx.context, spirv, options, *model,
    237                            DiagnosticsMessageHandler,
    238                            want_comments ? output_to_stderr : no_comments,
    239                            spvtools::comp::MarkvDebugConsumer(), &markv)) {
    240       std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
    241                 << std::endl;
    242       return 1;
    243     }
    244 
    245     if (!WriteFile<uint8_t>(output_filename, "wb", markv.data(), markv.size()))
    246       return 1;
    247   } else if (task == kDecode) {
    248     if (!ReadFile<uint8_t>(input_filename, "rb", &markv)) return 1;
    249     assert(!markv.empty());
    250 
    251     if (SPV_SUCCESS != spvtools::comp::MarkvToSpirv(
    252                            ctx.context, markv, options, *model,
    253                            DiagnosticsMessageHandler,
    254                            want_comments ? output_to_stderr : no_comments,
    255                            spvtools::comp::MarkvDebugConsumer(), &spirv)) {
    256       std::cerr << "error: Failed to decode " << input_filename << " to SPIR-V "
    257                 << std::endl;
    258       return 1;
    259     }
    260 
    261     if (!WriteFile<uint32_t>(output_filename, "wb", spirv.data(), spirv.size()))
    262       return 1;
    263   } else if (task == kTest) {
    264     if (!ReadFile<uint32_t>(input_filename, "rb", &spirv)) return 1;
    265     assert(!spirv.empty());
    266 
    267     std::vector<uint32_t> spirv_before;
    268     spvtools::Optimizer optimizer(kSpvEnv);
    269     optimizer.RegisterPass(spvtools::CreateCompactIdsPass());
    270     if (!optimizer.Run(spirv.data(), spirv.size(), &spirv_before)) {
    271       std::cerr << "error: Optimizer failure on: " << input_filename
    272                 << std::endl;
    273     }
    274 
    275     std::vector<std::string> encoder_instruction_bits;
    276     std::vector<std::string> encoder_instruction_comments;
    277     std::vector<std::vector<uint32_t>> encoder_instruction_words;
    278     std::vector<std::string> decoder_instruction_bits;
    279     std::vector<std::string> decoder_instruction_comments;
    280     std::vector<std::vector<uint32_t>> decoder_instruction_words;
    281 
    282     const auto encoder_debug_consumer = [&](const std::vector<uint32_t>& words,
    283                                             const std::string& bits,
    284                                             const std::string& comment) {
    285       encoder_instruction_words.push_back(words);
    286       encoder_instruction_bits.push_back(bits);
    287       encoder_instruction_comments.push_back(comment);
    288       return true;
    289     };
    290 
    291     if (SPV_SUCCESS != spvtools::comp::SpirvToMarkv(
    292                            ctx.context, spirv_before, options, *model,
    293                            DiagnosticsMessageHandler,
    294                            want_comments ? output_to_stderr : no_comments,
    295                            encoder_debug_consumer, &markv)) {
    296       std::cerr << "error: Failed to encode " << input_filename << " to MARK-V "
    297                 << std::endl;
    298       return 1;
    299     }
    300 
    301     const auto write_bug_report = [&]() {
    302       for (size_t inst_index = 0; inst_index < decoder_instruction_words.size();
    303            ++inst_index) {
    304         std::cerr << "\nInstruction #" << inst_index << std::endl;
    305         std::cerr << "\nEncoder words: ";
    306         for (uint32_t word : encoder_instruction_words[inst_index])
    307           std::cerr << word << " ";
    308         std::cerr << "\nDecoder words: ";
    309         for (uint32_t word : decoder_instruction_words[inst_index])
    310           std::cerr << word << " ";
    311         std::cerr << std::endl;
    312 
    313         std::cerr << "\nEncoder bits: " << encoder_instruction_bits[inst_index];
    314         std::cerr << "\nDecoder bits: " << decoder_instruction_bits[inst_index];
    315         std::cerr << std::endl;
    316 
    317         std::cerr << "\nEncoder comments:\n"
    318                   << encoder_instruction_comments[inst_index];
    319         std::cerr << "Decoder comments:\n"
    320                   << decoder_instruction_comments[inst_index];
    321         std::cerr << std::endl;
    322       }
    323     };
    324 
    325     const auto decoder_debug_consumer = [&](const std::vector<uint32_t>& words,
    326                                             const std::string& bits,
    327                                             const std::string& comment) {
    328       const size_t inst_index = decoder_instruction_words.size();
    329       if (inst_index >= encoder_instruction_words.size()) {
    330         write_bug_report();
    331         std::cerr << "error: Decoder has more instructions than encoder: "
    332                   << input_filename << std::endl;
    333         return false;
    334       }
    335 
    336       decoder_instruction_words.push_back(words);
    337       decoder_instruction_bits.push_back(bits);
    338       decoder_instruction_comments.push_back(comment);
    339 
    340       if (encoder_instruction_words[inst_index] !=
    341           decoder_instruction_words[inst_index]) {
    342         write_bug_report();
    343         std::cerr << "error: Words of the last decoded instruction differ from "
    344                      "reference: "
    345                   << input_filename << std::endl;
    346         return false;
    347       }
    348 
    349       if (encoder_instruction_bits[inst_index] !=
    350           decoder_instruction_bits[inst_index]) {
    351         write_bug_report();
    352         std::cerr << "error: Bits of the last decoded instruction differ from "
    353                      "reference: "
    354                   << input_filename << std::endl;
    355         return false;
    356       }
    357       return true;
    358     };
    359 
    360     std::vector<uint32_t> spirv_after;
    361     const spv_result_t decoding_result = spvtools::comp::MarkvToSpirv(
    362         ctx.context, markv, options, *model, DiagnosticsMessageHandler,
    363         want_comments ? output_to_stderr : no_comments, decoder_debug_consumer,
    364         &spirv_after);
    365 
    366     if (decoding_result == SPV_REQUESTED_TERMINATION) {
    367       std::cerr << "error: Decoding interrupted by the debugger: "
    368                 << input_filename << std::endl;
    369       return 1;
    370     }
    371 
    372     if (decoding_result != SPV_SUCCESS) {
    373       std::cerr << "error: Failed to decode encoded " << input_filename
    374                 << " back to SPIR-V " << std::endl;
    375       return 1;
    376     }
    377 
    378     assert(spirv_before.size() == spirv_after.size());
    379     assert(std::mismatch(std::next(spirv_before.begin(), 5), spirv_before.end(),
    380                          std::next(spirv_after.begin(), 5)) ==
    381            std::make_pair(spirv_before.end(), spirv_after.end()));
    382   }
    383 
    384   return 0;
    385 }
    386