# (code-viewer navigation header removed during extraction cleanup)
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Python command line interface for running TOCO."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import argparse
     22 import os
     23 import sys
     24 
     25 from tensorflow.lite.python import lite
     26 from tensorflow.lite.python import lite_constants
     27 from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
     28 from tensorflow.python import tf2
     29 from tensorflow.python.platform import app
     30 
     31 
     32 def _parse_array(values, type_fn=str):
     33   if values is not None:
     34     return [type_fn(val) for val in values.split(",") if val]
     35   return None
     36 
     37 
     38 def _parse_set(values):
     39   if values is not None:
     40     return set([item for item in values.split(",") if item])
     41   return None
     42 
     43 
     44 def _parse_inference_type(value, flag):
     45   """Converts the inference type to the value of the constant.
     46 
     47   Args:
     48     value: str representing the inference type.
     49     flag: str representing the flag name.
     50 
     51   Returns:
     52     tf.dtype.
     53 
     54   Raises:
     55     ValueError: Unsupported value.
     56   """
     57   if value == "FLOAT":
     58     return lite_constants.FLOAT
     59   if value == "QUANTIZED_UINT8":
     60     return lite_constants.QUANTIZED_UINT8
     61   raise ValueError("Unsupported value for --{0}. Only FLOAT and "
     62                    "QUANTIZED_UINT8 are supported.".format(flag))
     63 
     64 
     65 def _get_toco_converter(flags):
     66   """Makes a TFLiteConverter object based on the flags provided.
     67 
     68   Args:
     69     flags: argparse.Namespace object containing TFLite flags.
     70 
     71   Returns:
     72     TFLiteConverter object.
     73 
     74   Raises:
     75     ValueError: Invalid flags.
     76   """
     77   # Parse input and output arrays.
     78   input_arrays = _parse_array(flags.input_arrays)
     79   input_shapes = None
     80   if flags.input_shapes:
     81     input_shapes_list = [
     82         _parse_array(shape, type_fn=int)
     83         for shape in flags.input_shapes.split(":")
     84     ]
     85     input_shapes = dict(zip(input_arrays, input_shapes_list))
     86   output_arrays = _parse_array(flags.output_arrays)
     87 
     88   converter_kwargs = {
     89       "input_arrays": input_arrays,
     90       "input_shapes": input_shapes,
     91       "output_arrays": output_arrays
     92   }
     93 
     94   # Create TFLiteConverter.
     95   if flags.graph_def_file:
     96     converter_fn = lite.TFLiteConverter.from_frozen_graph
     97     converter_kwargs["graph_def_file"] = flags.graph_def_file
     98   elif flags.saved_model_dir:
     99     converter_fn = lite.TFLiteConverter.from_saved_model
    100     converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    101     converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    102     converter_kwargs["signature_key"] = flags.saved_model_signature_key
    103   elif flags.keras_model_file:
    104     converter_fn = lite.TFLiteConverter.from_keras_model_file
    105     converter_kwargs["model_file"] = flags.keras_model_file
    106   else:
    107     raise ValueError("--graph_def_file, --saved_model_dir, or "
    108                      "--keras_model_file must be specified.")
    109 
    110   return converter_fn(**converter_kwargs)
    111 
    112 
    113 def _convert_model(flags):
    114   """Calls function to convert the TensorFlow model into a TFLite model.
    115 
    116   Args:
    117     flags: argparse.Namespace object.
    118 
    119   Raises:
    120     ValueError: Invalid flags.
    121   """
    122   # Create converter.
    123   converter = _get_toco_converter(flags)
    124   if flags.inference_type:
    125     converter.inference_type = _parse_inference_type(flags.inference_type,
    126                                                      "inference_type")
    127   if flags.inference_input_type:
    128     converter.inference_input_type = _parse_inference_type(
    129         flags.inference_input_type, "inference_input_type")
    130   if flags.output_format:
    131     converter.output_format = _toco_flags_pb2.FileFormat.Value(
    132         flags.output_format)
    133 
    134   if flags.mean_values and flags.std_dev_values:
    135     input_arrays = converter.get_input_arrays()
    136     std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)
    137 
    138     # In quantized inference, mean_value has to be integer so that the real
    139     # value 0.0 is exactly representable.
    140     if converter.inference_type == lite_constants.QUANTIZED_UINT8:
    141       mean_values = _parse_array(flags.mean_values, type_fn=int)
    142     else:
    143       mean_values = _parse_array(flags.mean_values, type_fn=float)
    144     quant_stats = list(zip(mean_values, std_dev_values))
    145     if ((not flags.input_arrays and len(input_arrays) > 1) or
    146         (len(input_arrays) != len(quant_stats))):
    147       raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
    148                        "--mean_values. The flags must have the same number of "
    149                        "items. The current input arrays are '{0}'. "
    150                        "--input_arrays must be present when specifying "
    151                        "--std_dev_values and --mean_values with multiple input "
    152                        "tensors in order to map between names and "
    153                        "values.".format(",".join(input_arrays)))
    154     converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
    155   if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
    156                                                  not None):
    157     converter.default_ranges_stats = (flags.default_ranges_min,
    158                                       flags.default_ranges_max)
    159 
    160   if flags.drop_control_dependency:
    161     converter.drop_control_dependency = flags.drop_control_dependency
    162   if flags.reorder_across_fake_quant:
    163     converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
    164   if flags.change_concat_input_ranges:
    165     converter.change_concat_input_ranges = (
    166         flags.change_concat_input_ranges == "TRUE")
    167 
    168   if flags.allow_custom_ops:
    169     converter.allow_custom_ops = flags.allow_custom_ops
    170   if flags.target_ops:
    171     ops_set_options = lite.OpsSet.get_options()
    172     converter.target_ops = set()
    173     for option in flags.target_ops.split(","):
    174       if option not in ops_set_options:
    175         raise ValueError("Invalid value for --target_ops. Options: "
    176                          "{0}".format(",".join(ops_set_options)))
    177       converter.target_ops.add(lite.OpsSet(option))
    178 
    179   if flags.post_training_quantize:
    180     converter.post_training_quantize = flags.post_training_quantize
    181     if converter.inference_type == lite_constants.QUANTIZED_UINT8:
    182       print("--post_training_quantize quantizes a graph of inference_type "
    183             "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
    184       converter.inference_type = lite_constants.FLOAT
    185 
    186   if flags.dump_graphviz_dir:
    187     converter.dump_graphviz_dir = flags.dump_graphviz_dir
    188   if flags.dump_graphviz_video:
    189     converter.dump_graphviz_vode = flags.dump_graphviz_video
    190 
    191   # Convert model.
    192   output_data = converter.convert()
    193   with open(flags.output_file, "wb") as f:
    194     f.write(output_data)
    195 
    196 
    197 def _check_flags(flags, unparsed):
    198   """Checks the parsed and unparsed flags to ensure they are valid.
    199 
    200   Raises an error if previously support unparsed flags are found. Raises an
    201   error for parsed flags that don't meet the required conditions.
    202 
    203   Args:
    204     flags: argparse.Namespace object containing TFLite flags.
    205     unparsed: List of unparsed flags.
    206 
    207   Raises:
    208     ValueError: Invalid flags.
    209   """
    210 
    211   # Check unparsed flags for common mistakes based on previous TOCO.
    212   def _get_message_unparsed(flag, orig_flag, new_flag):
    213     if flag.startswith(orig_flag):
    214       return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
    215     return ""
    216 
    217   if unparsed:
    218     output = ""
    219     for flag in unparsed:
    220       output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
    221       output += _get_message_unparsed(flag, "--savedmodel_directory",
    222                                       "--saved_model_dir")
    223       output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
    224       output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
    225       output += _get_message_unparsed(flag, "--dump_graphviz",
    226                                       "--dump_graphviz_dir")
    227     if output:
    228       raise ValueError(output)
    229 
    230   # Check that flags are valid.
    231   if flags.graph_def_file and (not flags.input_arrays or
    232                                not flags.output_arrays):
    233     raise ValueError("--input_arrays and --output_arrays are required with "
    234                      "--graph_def_file")
    235 
    236   if flags.input_shapes:
    237     if not flags.input_arrays:
    238       raise ValueError("--input_shapes must be used with --input_arrays")
    239     if flags.input_shapes.count(":") != flags.input_arrays.count(","):
    240       raise ValueError("--input_shapes and --input_arrays must have the same "
    241                        "number of items")
    242 
    243   if flags.std_dev_values or flags.mean_values:
    244     if bool(flags.std_dev_values) != bool(flags.mean_values):
    245       raise ValueError("--std_dev_values and --mean_values must be used "
    246                        "together")
    247     if flags.std_dev_values.count(",") != flags.mean_values.count(","):
    248       raise ValueError("--std_dev_values, --mean_values must have the same "
    249                        "number of items")
    250 
    251   if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
    252     raise ValueError("--default_ranges_min and --default_ranges_max must be "
    253                      "used together")
    254 
    255   if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
    256     raise ValueError("--dump_graphviz_video must be used with "
    257                      "--dump_graphviz_dir")
    258 
    259 
def run_main(_):
  """Main in toco_convert.py.

  Builds the argument parser, validates the parsed flags, and converts the
  input TensorFlow model into a TFLite model written to --output_file.
  """
  # The session-based converter paths this CLI relies on are unavailable
  # when TF 2.x behavior is enabled.
  if tf2.enabled():
    raise ValueError("tflite_convert is currently unsupported in 2.0. "
                     "Please use the Python API "
                     "tf.lite.TFLiteConverter.from_concrete_function().")

  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Optimizing "
                   "Converter (TOCO)."))

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  # Input file flags. Exactly one of the three input formats must be given.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags. `type=str.upper` normalizes user input so the
  # choices lists match case-insensitively.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help="Target data type of real-number arrays in the output file.")
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range, Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED. It shares a destination with
  # --post_training_quantize and is hidden from --help via SUPPRESS.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))

  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow not supporting control dependencies. "
            "(default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operator overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action="store_true",
      help=("Boolean indicating whether to allow custom operations. When false "
            "any unknown operation is an error. When true, custom ops are "
            "created for any op that is unknown. The developer will need to "
            "provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "option may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of folder to dump the graphs at various stages of "
            "processing GraphViz .dot files. Preferred over --output_format="
            "GRAPHVIZ_DOT in order to keep the requirements of the output "
            "file."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation"))

  # parse_known_args keeps unrecognized flags so _check_flags can report
  # hints for renamed legacy TOCO flags.
  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
  try:
    _check_flags(tflite_flags, unparsed)
  except ValueError as e:
    # Mirror argparse's own error format: usage line plus "prog: error: ...".
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)
  _convert_model(tflite_flags)
    445 
    446 
    447 def main():
    448   app.run(main=run_main, argv=sys.argv[:1])
    449 
    450 
# Allow running this module directly as the tflite_convert CLI entry point.
if __name__ == "__main__":
  main()
    453