Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 #include "tensorflow/contrib/lite/toco/model_cmdline_flags.h"
     16 
     17 #include <string>
     18 #include <vector>
     19 
     20 #include "absl/strings/numbers.h"
     21 #include "absl/strings/str_join.h"
     22 #include "absl/strings/str_split.h"
     23 #include "absl/strings/string_view.h"
     24 #include "absl/strings/strip.h"
     25 #include "tensorflow/contrib/lite/toco/args.h"
     26 #include "tensorflow/contrib/lite/toco/toco_graphviz_dump_options.h"
     27 #include "tensorflow/contrib/lite/toco/toco_port.h"
     28 #include "tensorflow/core/platform/logging.h"
     29 #include "tensorflow/core/util/command_line_flags.h"
     30 
     31 // "batch" flag only exists internally
     32 #ifdef PLATFORM_GOOGLE
     33 #include "base/commandlineflags.h"
     34 #endif
     35 
     36 namespace toco {
     37 
     38 bool ParseModelFlagsFromCommandLineFlags(
     39     int* argc, char* argv[], string* msg,
     40     ParsedModelFlags* parsed_model_flags_ptr) {
     41   ParsedModelFlags& parsed_flags = *parsed_model_flags_ptr;
     42   using tensorflow::Flag;
     43   std::vector<tensorflow::Flag> flags = {
     44       Flag("input_array", parsed_flags.input_array.bind(),
     45            parsed_flags.input_array.default_value(),
     46            "Deprecated: use --input_arrays instead. Name of the input array. "
     47            "If not specified, will try to read "
     48            "that information from the input file."),
     49       Flag("input_arrays", parsed_flags.input_arrays.bind(),
     50            parsed_flags.input_arrays.default_value(),
     51            "Names of the output arrays, comma-separated. If not specified, "
     52            "will try to read that information from the input file."),
     53       Flag("output_array", parsed_flags.output_array.bind(),
     54            parsed_flags.output_array.default_value(),
     55            "Deprecated: use --output_arrays instead. Name of the output array, "
     56            "when specifying a unique output array. "
     57            "If not specified, will try to read that information from the "
     58            "input file."),
     59       Flag("output_arrays", parsed_flags.output_arrays.bind(),
     60            parsed_flags.output_arrays.default_value(),
     61            "Names of the output arrays, comma-separated. "
     62            "If not specified, will try to read "
     63            "that information from the input file."),
     64       Flag("input_shape", parsed_flags.input_shape.bind(),
     65            parsed_flags.input_shape.default_value(),
     66            "Deprecated: use --input_shapes instead. Input array shape. For "
     67            "many models the shape takes the form "
     68            "batch size, input array height, input array width, input array "
     69            "depth."),
     70       Flag("input_shapes", parsed_flags.input_shapes.bind(),
     71            parsed_flags.input_shapes.default_value(),
     72            "Shapes corresponding to --input_arrays, colon-separated. For "
     73            "many models each shape takes the form batch size, input array "
     74            "height, input array width, input array depth."),
     75       Flag("input_data_type", parsed_flags.input_data_type.bind(),
     76            parsed_flags.input_data_type.default_value(),
     77            "Deprecated: use --input_data_types instead. Input array type, if "
     78            "not already provided in the graph. "
     79            "Typically needs to be specified when passing arbitrary arrays "
     80            "to --input_array."),
     81       Flag("input_data_types", parsed_flags.input_data_types.bind(),
     82            parsed_flags.input_data_types.default_value(),
     83            "Input arrays types, comma-separated, if not already provided in "
     84            "the graph. "
     85            "Typically needs to be specified when passing arbitrary arrays "
     86            "to --input_arrays."),
     87       Flag("mean_value", parsed_flags.mean_value.bind(),
     88            parsed_flags.mean_value.default_value(),
     89            "Deprecated: use --mean_values instead. mean_value parameter for "
     90            "image models, used to compute input "
     91            "activations from input pixel data."),
     92       Flag("mean_values", parsed_flags.mean_values.bind(),
     93            parsed_flags.mean_values.default_value(),
     94            "mean_values parameter for image models, comma-separated list of "
     95            "doubles, used to compute input activations from input pixel "
     96            "data. Each entry in the list should match an entry in "
     97            "--input_arrays."),
     98       Flag("std_value", parsed_flags.std_value.bind(),
     99            parsed_flags.std_value.default_value(),
    100            "Deprecated: use --std_values instead. std_value parameter for "
    101            "image models, used to compute input "
    102            "activations from input pixel data."),
    103       Flag("std_values", parsed_flags.std_values.bind(),
    104            parsed_flags.std_values.default_value(),
    105            "std_value parameter for image models, comma-separated list of "
    106            "doubles, used to compute input activations from input pixel "
    107            "data. Each entry in the list should match an entry in "
    108            "--input_arrays."),
    109       Flag("variable_batch", parsed_flags.variable_batch.bind(),
    110            parsed_flags.variable_batch.default_value(),
    111            "If true, the model accepts an arbitrary batch size. Mutually "
    112            "exclusive "
    113            "with the 'batch' field: at most one of these two fields can be "
    114            "set."),
    115       Flag("rnn_states", parsed_flags.rnn_states.bind(),
    116            parsed_flags.rnn_states.default_value(), ""),
    117       Flag("model_checks", parsed_flags.model_checks.bind(),
    118            parsed_flags.model_checks.default_value(),
    119            "A list of model checks to be applied to verify the form of the "
    120            "model.  Applied after the graph transformations after import."),
    121       Flag("graphviz_first_array", parsed_flags.graphviz_first_array.bind(),
    122            parsed_flags.graphviz_first_array.default_value(),
    123            "If set, defines the start of the sub-graph to be dumped to "
    124            "GraphViz."),
    125       Flag(
    126           "graphviz_last_array", parsed_flags.graphviz_last_array.bind(),
    127           parsed_flags.graphviz_last_array.default_value(),
    128           "If set, defines the end of the sub-graph to be dumped to GraphViz."),
    129       Flag("dump_graphviz", parsed_flags.dump_graphviz.bind(),
    130            parsed_flags.dump_graphviz.default_value(),
    131            "Dump graphviz during LogDump call. If string is non-empty then "
    132            "it defines path to dump, otherwise will skip dumping."),
    133       Flag("dump_graphviz_video", parsed_flags.dump_graphviz_video.bind(),
    134            parsed_flags.dump_graphviz_video.default_value(),
    135            "If true, will dump graphviz at each "
    136            "graph transformation, which may be used to generate a video."),
    137       Flag("allow_nonexistent_arrays",
    138            parsed_flags.allow_nonexistent_arrays.bind(),
    139            parsed_flags.allow_nonexistent_arrays.default_value(),
    140            "If true, will allow passing inexistent arrays in --input_arrays "
    141            "and --output_arrays. This makes little sense, is only useful to "
    142            "more easily get graph visualizations."),
    143       Flag("allow_nonascii_arrays", parsed_flags.allow_nonascii_arrays.bind(),
    144            parsed_flags.allow_nonascii_arrays.default_value(),
    145            "If true, will allow passing non-ascii-printable characters in "
    146            "--input_arrays and --output_arrays. By default (if false), only "
    147            "ascii printable characters are allowed, i.e. character codes "
    148            "ranging from 32 to 127. This is disallowed by default so as to "
    149            "catch common copy-and-paste issues where invisible unicode "
    150            "characters are unwittingly added to these strings."),
    151       Flag(
    152           "arrays_extra_info_file", parsed_flags.arrays_extra_info_file.bind(),
    153           parsed_flags.arrays_extra_info_file.default_value(),
    154           "Path to an optional file containing a serialized ArraysExtraInfo "
    155           "proto allowing to pass extra information about arrays not specified "
    156           "in the input model file, such as extra MinMax information."),
    157   };
    158   bool asked_for_help =
    159       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
    160   if (asked_for_help) {
    161     *msg += tensorflow::Flags::Usage(argv[0], flags);
    162     return false;
    163   } else {
    164     if (!tensorflow::Flags::Parse(argc, argv, flags)) return false;
    165   }
    166   auto& dump_options = *GraphVizDumpOptions::singleton();
    167   dump_options.graphviz_first_array = parsed_flags.graphviz_first_array.value();
    168   dump_options.graphviz_last_array = parsed_flags.graphviz_last_array.value();
    169   dump_options.dump_graphviz_video = parsed_flags.dump_graphviz_video.value();
    170   dump_options.dump_graphviz = parsed_flags.dump_graphviz.value();
    171 
    172   return true;
    173 }
    174 
    175 void ReadModelFlagsFromCommandLineFlags(
    176     const ParsedModelFlags& parsed_model_flags, ModelFlags* model_flags) {
    177   toco::port::CheckInitGoogleIsDone("InitGoogle is not done yet");
    178 
    179 // "batch" flag only exists internally
    180 #ifdef PLATFORM_GOOGLE
    181   CHECK(!((base::SpecifiedOnCommandLine("batch") &&
    182            parsed_model_flags.variable_batch.specified())))
    183       << "The --batch and --variable_batch flags are mutually exclusive.";
    184 #endif
    185   CHECK(!(parsed_model_flags.output_array.specified() &&
    186           parsed_model_flags.output_arrays.specified()))
    187       << "The --output_array and --vs flags are mutually exclusive.";
    188 
    189   if (parsed_model_flags.output_array.specified()) {
    190     model_flags->add_output_arrays(parsed_model_flags.output_array.value());
    191   }
    192 
    193   if (parsed_model_flags.output_arrays.specified()) {
    194     std::vector<string> output_arrays =
    195         absl::StrSplit(parsed_model_flags.output_arrays.value(), ',');
    196     for (const string& output_array : output_arrays) {
    197       model_flags->add_output_arrays(output_array);
    198     }
    199   }
    200 
    201   const bool uses_single_input_flags =
    202       parsed_model_flags.input_array.specified() ||
    203       parsed_model_flags.mean_value.specified() ||
    204       parsed_model_flags.std_value.specified() ||
    205       parsed_model_flags.input_shape.specified();
    206 
    207   const bool uses_multi_input_flags =
    208       parsed_model_flags.input_arrays.specified() ||
    209       parsed_model_flags.mean_values.specified() ||
    210       parsed_model_flags.std_values.specified() ||
    211       parsed_model_flags.input_shapes.specified();
    212 
    213   QCHECK(!(uses_single_input_flags && uses_multi_input_flags))
    214       << "Use either the singular-form input flags (--input_array, "
    215          "--input_shape, --mean_value, --std_value) or the plural form input "
    216          "flags (--input_arrays, --input_shapes, --mean_values, --std_values), "
    217          "but not both forms within the same command line.";
    218 
    219   if (parsed_model_flags.input_array.specified()) {
    220     QCHECK(uses_single_input_flags);
    221     model_flags->add_input_arrays()->set_name(
    222         parsed_model_flags.input_array.value());
    223   }
    224   if (parsed_model_flags.input_arrays.specified()) {
    225     QCHECK(uses_multi_input_flags);
    226     for (const auto& input_array :
    227          absl::StrSplit(parsed_model_flags.input_arrays.value(), ',')) {
    228       model_flags->add_input_arrays()->set_name(string(input_array));
    229     }
    230   }
    231   if (parsed_model_flags.mean_value.specified()) {
    232     QCHECK(uses_single_input_flags);
    233     model_flags->mutable_input_arrays(0)->set_mean_value(
    234         parsed_model_flags.mean_value.value());
    235   }
    236   if (parsed_model_flags.mean_values.specified()) {
    237     QCHECK(uses_multi_input_flags);
    238     std::vector<string> mean_values =
    239         absl::StrSplit(parsed_model_flags.mean_values.value(), ',');
    240     QCHECK(mean_values.size() == model_flags->input_arrays_size());
    241     for (int i = 0; i < mean_values.size(); ++i) {
    242       char* last = nullptr;
    243       model_flags->mutable_input_arrays(i)->set_mean_value(
    244           strtod(mean_values[i].data(), &last));
    245       CHECK(last != mean_values[i].data());
    246     }
    247   }
    248   if (parsed_model_flags.std_value.specified()) {
    249     QCHECK(uses_single_input_flags);
    250     model_flags->mutable_input_arrays(0)->set_std_value(
    251         parsed_model_flags.std_value.value());
    252   }
    253   if (parsed_model_flags.std_values.specified()) {
    254     QCHECK(uses_multi_input_flags);
    255     std::vector<string> std_values =
    256         absl::StrSplit(parsed_model_flags.std_values.value(), ',');
    257     QCHECK(std_values.size() == model_flags->input_arrays_size());
    258     for (int i = 0; i < std_values.size(); ++i) {
    259       char* last = nullptr;
    260       model_flags->mutable_input_arrays(i)->set_std_value(
    261           strtod(std_values[i].data(), &last));
    262       CHECK(last != std_values[i].data());
    263     }
    264   }
    265   if (parsed_model_flags.input_data_type.specified()) {
    266     QCHECK(uses_single_input_flags);
    267     IODataType type;
    268     QCHECK(IODataType_Parse(parsed_model_flags.input_data_type.value(), &type));
    269     model_flags->mutable_input_arrays(0)->set_data_type(type);
    270   }
    271   if (parsed_model_flags.input_data_types.specified()) {
    272     QCHECK(uses_multi_input_flags);
    273     std::vector<string> input_data_types =
    274         absl::StrSplit(parsed_model_flags.input_data_types.value(), ',');
    275     QCHECK(input_data_types.size() == model_flags->input_arrays_size());
    276     for (int i = 0; i < input_data_types.size(); ++i) {
    277       IODataType type;
    278       QCHECK(IODataType_Parse(input_data_types[i], &type));
    279       model_flags->mutable_input_arrays(i)->set_data_type(type);
    280     }
    281   }
    282   if (parsed_model_flags.input_shape.specified()) {
    283     QCHECK(uses_single_input_flags);
    284     if (model_flags->input_arrays().empty()) {
    285       model_flags->add_input_arrays();
    286     }
    287     auto* shape = model_flags->mutable_input_arrays(0)->mutable_shape();
    288     shape->clear_dims();
    289     const IntList& list = parsed_model_flags.input_shape.value();
    290     for (auto& dim : list.elements) {
    291       shape->add_dims(dim);
    292     }
    293   }
    294   if (parsed_model_flags.input_shapes.specified()) {
    295     QCHECK(uses_multi_input_flags);
    296     std::vector<string> input_shapes =
    297         absl::StrSplit(parsed_model_flags.input_shapes.value(), ':');
    298     QCHECK(input_shapes.size() == model_flags->input_arrays_size());
    299     for (int i = 0; i < input_shapes.size(); ++i) {
    300       auto* shape = model_flags->mutable_input_arrays(i)->mutable_shape();
    301       shape->clear_dims();
    302       for (const auto& dim_str : absl::StrSplit(input_shapes[i], ',')) {
    303         int size;
    304         CHECK(absl::SimpleAtoi(dim_str, &size))
    305             << "Failed to parse input_shape: " << input_shapes[i];
    306         shape->add_dims(size);
    307       }
    308     }
    309   }
    310 
    311 #define READ_MODEL_FLAG(name)                                   \
    312   do {                                                          \
    313     if (parsed_model_flags.name.specified()) {                  \
    314       model_flags->set_##name(parsed_model_flags.name.value()); \
    315     }                                                           \
    316   } while (false)
    317 
    318   READ_MODEL_FLAG(variable_batch);
    319 
    320 #undef READ_MODEL_FLAG
    321 
    322   for (const auto& element : parsed_model_flags.rnn_states.value().elements) {
    323     auto* rnn_state_proto = model_flags->add_rnn_states();
    324     for (const auto& kv_pair : element) {
    325       const string& key = kv_pair.first;
    326       const string& value = kv_pair.second;
    327       if (key == "state_array") {
    328         rnn_state_proto->set_state_array(value);
    329       } else if (key == "back_edge_source_array") {
    330         rnn_state_proto->set_back_edge_source_array(value);
    331       } else if (key == "size") {
    332         int32 size = 0;
    333         CHECK(absl::SimpleAtoi(value, &size));
    334         CHECK_GT(size, 0);
    335         rnn_state_proto->set_size(size);
    336       } else {
    337         LOG(FATAL) << "Unknown key '" << key << "' in --rnn_states";
    338       }
    339     }
    340     CHECK(rnn_state_proto->has_state_array() &&
    341           rnn_state_proto->has_back_edge_source_array() &&
    342           rnn_state_proto->has_size())
    343         << "--rnn_states must include state_array, back_edge_source_array and "
    344            "size.";
    345   }
    346 
    347   for (const auto& element : parsed_model_flags.model_checks.value().elements) {
    348     auto* model_check_proto = model_flags->add_model_checks();
    349     for (const auto& kv_pair : element) {
    350       const string& key = kv_pair.first;
    351       const string& value = kv_pair.second;
    352       if (key == "count_type") {
    353         model_check_proto->set_count_type(value);
    354       } else if (key == "count_min") {
    355         int32 count = 0;
    356         CHECK(absl::SimpleAtoi(value, &count));
    357         CHECK_GE(count, -1);
    358         model_check_proto->set_count_min(count);
    359       } else if (key == "count_max") {
    360         int32 count = 0;
    361         CHECK(absl::SimpleAtoi(value, &count));
    362         CHECK_GE(count, -1);
    363         model_check_proto->set_count_max(count);
    364       } else {
    365         LOG(FATAL) << "Unknown key '" << key << "' in --model_checks";
    366       }
    367     }
    368   }
    369 
    370   model_flags->set_allow_nonascii_arrays(
    371       parsed_model_flags.allow_nonascii_arrays.value());
    372   model_flags->set_allow_nonexistent_arrays(
    373       parsed_model_flags.allow_nonexistent_arrays.value());
    374 
    375   if (parsed_model_flags.arrays_extra_info_file.specified()) {
    376     string arrays_extra_info_file_contents;
    377     port::file::GetContents(parsed_model_flags.arrays_extra_info_file.value(),
    378                             &arrays_extra_info_file_contents,
    379                             port::file::Defaults());
    380     ParseFromStringEitherTextOrBinary(arrays_extra_info_file_contents,
    381                                       model_flags->mutable_arrays_extra_info());
    382   }
    383 }
    384 
    385 ParsedModelFlags* UncheckedGlobalParsedModelFlags(bool must_already_exist) {
    386   static auto* flags = [must_already_exist]() {
    387     if (must_already_exist) {
    388       fprintf(stderr, __FILE__
    389               ":"
    390               "GlobalParsedModelFlags() used without initialization\n");
    391       fflush(stderr);
    392       abort();
    393     }
    394     return new toco::ParsedModelFlags;
    395   }();
    396   return flags;
    397 }
    398 
    399 ParsedModelFlags* GlobalParsedModelFlags() {
    400   return UncheckedGlobalParsedModelFlags(true);
    401 }
    402 
    403 void ParseModelFlagsOrDie(int* argc, char* argv[]) {
    404   // TODO(aselle): in the future allow Google version to use
    405   // flags, and only use this mechanism for open source
    406   auto* flags = UncheckedGlobalParsedModelFlags(false);
    407   string msg;
    408   bool model_success =
    409       toco::ParseModelFlagsFromCommandLineFlags(argc, argv, &msg, flags);
    410   if (!model_success || !msg.empty()) {
    411     // Log in non-standard way since this happens pre InitGoogle.
    412     fprintf(stderr, "%s", msg.c_str());
    413     fflush(stderr);
    414     abort();
    415   }
    416 }
    417 
    418 }  // namespace toco
    419