Home | History | Annotate | Download | only in toco
      1 /* Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 
      3 Licensed under the Apache License, Version 2.0 (the "License");
      4 you may not use this file except in compliance with the License.
      5 You may obtain a copy of the License at
      6 
      7     http://www.apache.org/licenses/LICENSE-2.0
      8 
      9 Unless required by applicable law or agreed to in writing, software
     10 distributed under the License is distributed on an "AS IS" BASIS,
     11 WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 See the License for the specific language governing permissions and
     13 limitations under the License.
     14 ==============================================================================*/
     15 
     16 #include <string>
     17 #include <vector>
     18 
     19 #include "absl/strings/numbers.h"
     20 #include "absl/strings/str_join.h"
     21 #include "absl/strings/str_split.h"
     22 #include "absl/strings/strip.h"
     23 #include "tensorflow/contrib/lite/toco/toco_cmdline_flags.h"
     24 #include "tensorflow/contrib/lite/toco/toco_port.h"
     25 #include "tensorflow/core/platform/logging.h"
     26 #include "tensorflow/core/util/command_line_flags.h"
     27 
     28 namespace toco {
     29 
     30 bool ParseTocoFlagsFromCommandLineFlags(
     31     int* argc, char* argv[], string* msg,
     32     ParsedTocoFlags* parsed_toco_flags_ptr) {
     33   using tensorflow::Flag;
     34   ParsedTocoFlags& parsed_flags = *parsed_toco_flags_ptr;
     35   std::vector<tensorflow::Flag> flags = {
     36       Flag("input_file", parsed_flags.input_file.bind(),
     37            parsed_flags.input_file.default_value(),
     38            "Input file (model of any supported format). For Protobuf "
     39            "formats, both text and binary are supported regardless of file "
     40            "extension."),
     41       Flag("output_file", parsed_flags.output_file.bind(),
     42            parsed_flags.output_file.default_value(),
     43            "Output file. "
     44            "For Protobuf formats, the binary format will be used."),
     45       Flag("input_format", parsed_flags.input_format.bind(),
     46            parsed_flags.input_format.default_value(),
     47            "Input file format. One of: TENSORFLOW_GRAPHDEF, TFLITE."),
     48       Flag("output_format", parsed_flags.output_format.bind(),
     49            parsed_flags.output_format.default_value(),
     50            "Output file format. "
     51            "One of TENSORFLOW_GRAPHDEF, TFLITE, GRAPHVIZ_DOT."),
     52       Flag("default_ranges_min", parsed_flags.default_ranges_min.bind(),
     53            parsed_flags.default_ranges_min.default_value(),
     54            "If defined, will be used as the default value for the min bound "
     55            "of min/max ranges used for quantization."),
     56       Flag("default_ranges_max", parsed_flags.default_ranges_max.bind(),
     57            parsed_flags.default_ranges_max.default_value(),
     58            "If defined, will be used as the default value for the max bound "
     59            "of min/max ranges used for quantization."),
     60       Flag("inference_type", parsed_flags.inference_type.bind(),
     61            parsed_flags.inference_type.default_value(),
     62            "Target data type of arrays in the output file (for input_arrays, "
     63            "this may be overridden by inference_input_type). "
     64            "One of FLOAT, QUANTIZED_UINT8."),
     65       Flag("inference_input_type", parsed_flags.inference_input_type.bind(),
     66            parsed_flags.inference_input_type.default_value(),
     67            "Target data type of input arrays. "
     68            "If not specified, inference_type is used. "
     69            "One of FLOAT, QUANTIZED_UINT8."),
     70       Flag("input_type", parsed_flags.input_type.bind(),
     71            parsed_flags.input_type.default_value(),
     72            "Deprecated ambiguous flag that set both --input_data_types and "
     73            "--inference_input_type."),
     74       Flag("input_types", parsed_flags.input_types.bind(),
     75            parsed_flags.input_types.default_value(),
     76            "Deprecated ambiguous flag that set both --input_data_types and "
     77            "--inference_input_type. Was meant to be a "
     78            "comma-separated list, but this was deprecated before "
     79            "multiple-input-types was ever properly supported."),
     80 
     81       Flag("drop_fake_quant", parsed_flags.drop_fake_quant.bind(),
     82            parsed_flags.drop_fake_quant.default_value(),
     83            "Ignore and discard FakeQuant nodes. For instance, to "
     84            "generate plain float code without fake-quantization from a "
     85            "quantized graph."),
     86       Flag(
     87           "reorder_across_fake_quant",
     88           parsed_flags.reorder_across_fake_quant.bind(),
     89           parsed_flags.reorder_across_fake_quant.default_value(),
     90           "Normally, FakeQuant nodes must be strict boundaries for graph "
     91           "transformations, in order to ensure that quantized inference has "
     92           "the exact same arithmetic behavior as quantized training --- which "
     93           "is the whole point of quantized training and of FakeQuant nodes in "
     94           "the first place. "
     95           "However, that entails subtle requirements on where exactly "
     96           "FakeQuant nodes must be placed in the graph. Some quantized graphs "
     97           "have FakeQuant nodes at unexpected locations, that prevent graph "
     98           "transformations that are necessary in order to generate inference "
     99           "code for these graphs. Such graphs should be fixed, but as a "
    100           "temporary work-around, setting this reorder_across_fake_quant flag "
    101           "allows TOCO to perform necessary graph transformaitons on them, "
    102           "at the cost of no longer faithfully matching inference and training "
    103           "arithmetic."),
    104       Flag("allow_custom_ops", parsed_flags.allow_custom_ops.bind(),
    105            parsed_flags.allow_custom_ops.default_value(),
    106            "If true, allow TOCO to create TF Lite Custom operators for all the "
    107            "unsupported TensorFlow ops."),
    108       Flag(
    109           "drop_control_dependency",
    110           parsed_flags.drop_control_dependency.bind(),
    111           parsed_flags.drop_control_dependency.default_value(),
    112           "If true, ignore control dependency requirements in input TensorFlow "
    113           "GraphDef. Otherwise an error will be raised upon control dependency "
    114           "inputs."),
    115   };
    116   bool asked_for_help =
    117       *argc == 2 && (!strcmp(argv[1], "--help") || !strcmp(argv[1], "-help"));
    118   if (asked_for_help) {
    119     *msg += tensorflow::Flags::Usage(argv[0], flags);
    120     return false;
    121   } else {
    122     return tensorflow::Flags::Parse(argc, argv, flags);
    123   }
    124 }
    125 
    126 void ReadTocoFlagsFromCommandLineFlags(const ParsedTocoFlags& parsed_toco_flags,
    127                                        TocoFlags* toco_flags) {
    128   namespace port = toco::port;
    129   port::CheckInitGoogleIsDone("InitGoogle is not done yet");
    130 
    131   enum class FlagRequirement { kNone, kMustBeSpecified, kMustNotBeSpecified };
    132 
    133 #define ENFORCE_FLAG_REQUIREMENT(name, requirement)                          \
    134   do {                                                                       \
    135     if (requirement == FlagRequirement::kMustBeSpecified) {                  \
    136       QCHECK(parsed_toco_flags.name.specified())                             \
    137           << "Missing required flag: " << #name;                             \
    138     }                                                                        \
    139     if (requirement == FlagRequirement::kMustNotBeSpecified) {               \
    140       QCHECK(!parsed_toco_flags.name.specified())                            \
    141           << "Given other flags, this flag should not have been specified: " \
    142           << #name;                                                          \
    143     }                                                                        \
    144   } while (false)
    145 #define READ_TOCO_FLAG(name, requirement)                     \
    146   ENFORCE_FLAG_REQUIREMENT(name, requirement);                \
    147   do {                                                        \
    148     if (parsed_toco_flags.name.specified()) {                 \
    149       toco_flags->set_##name(parsed_toco_flags.name.value()); \
    150     }                                                         \
    151   } while (false)
    152 
    153 #define PARSE_TOCO_FLAG(Type, name, requirement)               \
    154   ENFORCE_FLAG_REQUIREMENT(name, requirement);                 \
    155   do {                                                         \
    156     if (parsed_toco_flags.name.specified()) {                  \
    157       Type x;                                                  \
    158       QCHECK(Type##_Parse(parsed_toco_flags.name.value(), &x)) \
    159           << "Unrecognized " << #Type << " value "             \
    160           << parsed_toco_flags.name.value();                   \
    161       toco_flags->set_##name(x);                               \
    162     }                                                          \
    163   } while (false)
    164 
    165   PARSE_TOCO_FLAG(FileFormat, input_format, FlagRequirement::kMustBeSpecified);
    166   PARSE_TOCO_FLAG(FileFormat, output_format, FlagRequirement::kMustBeSpecified);
    167   PARSE_TOCO_FLAG(IODataType, inference_type, FlagRequirement::kNone);
    168   PARSE_TOCO_FLAG(IODataType, inference_input_type, FlagRequirement::kNone);
    169   READ_TOCO_FLAG(default_ranges_min, FlagRequirement::kNone);
    170   READ_TOCO_FLAG(default_ranges_max, FlagRequirement::kNone);
    171   READ_TOCO_FLAG(drop_fake_quant, FlagRequirement::kNone);
    172   READ_TOCO_FLAG(reorder_across_fake_quant, FlagRequirement::kNone);
    173   READ_TOCO_FLAG(allow_custom_ops, FlagRequirement::kNone);
    174   READ_TOCO_FLAG(drop_control_dependency, FlagRequirement::kNone);
    175 
    176   // Deprecated flag handling.
    177   if (parsed_toco_flags.input_type.specified()) {
    178     LOG(WARNING)
    179         << "--input_type is deprecated. It was an ambiguous flag that set both "
    180            "--input_data_types and --inference_input_type. If you are trying "
    181            "to complement the input file with information about the type of "
    182            "input arrays, use --input_data_type. If you are trying to control "
    183            "the quantization/dequantization of real-numbers input arrays in "
    184            "the output file, use --inference_input_type.";
    185     toco::IODataType input_type;
    186     QCHECK(toco::IODataType_Parse(parsed_toco_flags.input_type.value(),
    187                                   &input_type));
    188     toco_flags->set_inference_input_type(input_type);
    189   }
    190   if (parsed_toco_flags.input_types.specified()) {
    191     LOG(WARNING)
    192         << "--input_types is deprecated. It was an ambiguous flag that set "
    193            "both --input_data_types and --inference_input_type. If you are "
    194            "trying to complement the input file with information about the "
    195            "type of input arrays, use --input_data_type. If you are trying to "
    196            "control the quantization/dequantization of real-numbers input "
    197            "arrays in the output file, use --inference_input_type.";
    198     std::vector<string> input_types =
    199         absl::StrSplit(parsed_toco_flags.input_types.value(), ',');
    200     QCHECK(!input_types.empty());
    201     for (int i = 1; i < input_types.size(); i++) {
    202       QCHECK_EQ(input_types[i], input_types[0]);
    203     }
    204     toco::IODataType input_type;
    205     QCHECK(toco::IODataType_Parse(input_types[0], &input_type));
    206     toco_flags->set_inference_input_type(input_type);
    207   }
    208 
    209 #undef READ_TOCO_FLAG
    210 #undef PARSE_TOCO_FLAG
    211 }
    212 }  // namespace toco
    213