# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Python command line interface for running TOCO."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import argparse
import os
import sys

from tensorflow.lite.python import lite
from tensorflow.lite.python import lite_constants
from tensorflow.lite.toco import toco_flags_pb2 as _toco_flags_pb2
from tensorflow.python import tf2
from tensorflow.python.platform import app


def _parse_array(values, type_fn=str):
  if values is not None:
    return [type_fn(val) for val in values.split(",") if val]
  return None


def _parse_set(values):
  if values is not None:
    return set([item for item in values.split(",") if item])
  return None


def _parse_inference_type(value, flag):
  """Converts the inference type to the value of the constant.

  Args:
    value: str representing the inference type.
    flag: str representing the flag name.

  Returns:
    tf.dtype.

  Raises:
    ValueError: Unsupported value.
  """
  if value == "FLOAT":
    return lite_constants.FLOAT
  if value == "QUANTIZED_UINT8":
    return lite_constants.QUANTIZED_UINT8
  raise ValueError("Unsupported value for --{0}. Only FLOAT and "
                   "QUANTIZED_UINT8 are supported.".format(flag))


def _get_toco_converter(flags):
  """Makes a TFLiteConverter object based on the flags provided.

  Args:
    flags: argparse.Namespace object containing TFLite flags.

  Returns:
    TFLiteConverter object.

  Raises:
    ValueError: Invalid flags.
  """
  # Parse input and output arrays.
  input_arrays = _parse_array(flags.input_arrays)
  input_shapes = None
  if flags.input_shapes:
    input_shapes_list = [
        _parse_array(shape, type_fn=int)
        for shape in flags.input_shapes.split(":")
    ]
    input_shapes = dict(zip(input_arrays, input_shapes_list))
  output_arrays = _parse_array(flags.output_arrays)

  converter_kwargs = {
      "input_arrays": input_arrays,
      "input_shapes": input_shapes,
      "output_arrays": output_arrays
  }
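
  # Illustrative example (tensor names and shapes are placeholders): with
  # --input_arrays=input_a,input_b and --input_shapes=1,224,224,3:1,10 the
  # parsing above yields
  # input_shapes == {"input_a": [1, 224, 224, 3], "input_b": [1, 10]}.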

  # Create TFLiteConverter.
  if flags.graph_def_file:
    converter_fn = lite.TFLiteConverter.from_frozen_graph
    converter_kwargs["graph_def_file"] = flags.graph_def_file
  elif flags.saved_model_dir:
    converter_fn = lite.TFLiteConverter.from_saved_model
    converter_kwargs["saved_model_dir"] = flags.saved_model_dir
    converter_kwargs["tag_set"] = _parse_set(flags.saved_model_tag_set)
    converter_kwargs["signature_key"] = flags.saved_model_signature_key
  elif flags.keras_model_file:
    converter_fn = lite.TFLiteConverter.from_keras_model_file
    converter_kwargs["model_file"] = flags.keras_model_file
  else:
    raise ValueError("--graph_def_file, --saved_model_dir, or "
                     "--keras_model_file must be specified.")

  return converter_fn(**converter_kwargs)


def _convert_model(flags):
  """Calls function to convert the TensorFlow model into a TFLite model.

  Args:
    flags: argparse.Namespace object.

  Raises:
    ValueError: Invalid flags.
  """
  # Create converter.
  converter = _get_toco_converter(flags)
  if flags.inference_type:
    converter.inference_type = _parse_inference_type(flags.inference_type,
                                                     "inference_type")
  if flags.inference_input_type:
    converter.inference_input_type = _parse_inference_type(
        flags.inference_input_type, "inference_input_type")
  if flags.output_format:
    converter.output_format = _toco_flags_pb2.FileFormat.Value(
        flags.output_format)
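
  # Per the TFLiteConverter documentation, (mean_value, std_dev_value) describe
  # how a quantized input value maps back to a real number:
  #   real_value = (quantized_input_value - mean_value) / std_dev_value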
  if flags.mean_values and flags.std_dev_values:
    input_arrays = converter.get_input_arrays()
    std_dev_values = _parse_array(flags.std_dev_values, type_fn=float)

    # In quantized inference, mean_value has to be an integer so that the real
    # value 0.0 is exactly representable.
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      mean_values = _parse_array(flags.mean_values, type_fn=int)
    else:
      mean_values = _parse_array(flags.mean_values, type_fn=float)
    quant_stats = list(zip(mean_values, std_dev_values))
    if ((not flags.input_arrays and len(input_arrays) > 1) or
        (len(input_arrays) != len(quant_stats))):
      raise ValueError("Mismatching --input_arrays, --std_dev_values, and "
                       "--mean_values. The flags must have the same number of "
                       "items. The current input arrays are '{0}'. "
                       "--input_arrays must be present when specifying "
                       "--std_dev_values and --mean_values with multiple input "
                       "tensors in order to map between names and "
                       "values.".format(",".join(input_arrays)))
    converter.quantized_input_stats = dict(zip(input_arrays, quant_stats))
  if (flags.default_ranges_min is not None) and (flags.default_ranges_max is
                                                 not None):
    converter.default_ranges_stats = (flags.default_ranges_min,
                                      flags.default_ranges_max)

  if flags.drop_control_dependency:
    converter.drop_control_dependency = flags.drop_control_dependency
  if flags.reorder_across_fake_quant:
    converter.reorder_across_fake_quant = flags.reorder_across_fake_quant
  if flags.change_concat_input_ranges:
    converter.change_concat_input_ranges = (
        flags.change_concat_input_ranges == "TRUE")

  if flags.allow_custom_ops:
    converter.allow_custom_ops = flags.allow_custom_ops
  if flags.target_ops:
    ops_set_options = lite.OpsSet.get_options()
    converter.target_ops = set()
    for option in flags.target_ops.split(","):
      if option not in ops_set_options:
        raise ValueError("Invalid value for --target_ops. Options: "
                         "{0}".format(",".join(ops_set_options)))
      converter.target_ops.add(lite.OpsSet(option))

  if flags.post_training_quantize:
    converter.post_training_quantize = flags.post_training_quantize
    if converter.inference_type == lite_constants.QUANTIZED_UINT8:
      print("--post_training_quantize quantizes a graph of inference_type "
            "FLOAT. Overriding inference type QUANTIZED_UINT8 to FLOAT.")
      converter.inference_type = lite_constants.FLOAT

  if flags.dump_graphviz_dir:
    converter.dump_graphviz_dir = flags.dump_graphviz_dir
  if flags.dump_graphviz_video:
    converter.dump_graphviz_video = flags.dump_graphviz_video

  # Convert model.
  output_data = converter.convert()
  with open(flags.output_file, "wb") as f:
    f.write(output_data)


def _check_flags(flags, unparsed):
  """Checks the parsed and unparsed flags to ensure they are valid.

  Raises an error if previously supported unparsed flags are found. Raises an
  error for parsed flags that don't meet the required conditions.

  Args:
    flags: argparse.Namespace object containing TFLite flags.
    unparsed: List of unparsed flags.

  Raises:
    ValueError: Invalid flags.
  """

  # Check unparsed flags for common mistakes based on previous TOCO.
  def _get_message_unparsed(flag, orig_flag, new_flag):
    if flag.startswith(orig_flag):
      return "\n  Use {0} instead of {1}".format(new_flag, orig_flag)
    return ""

  if unparsed:
    output = ""
    for flag in unparsed:
      output += _get_message_unparsed(flag, "--input_file", "--graph_def_file")
      output += _get_message_unparsed(flag, "--savedmodel_directory",
                                      "--saved_model_dir")
      output += _get_message_unparsed(flag, "--std_value", "--std_dev_values")
      output += _get_message_unparsed(flag, "--batch_size", "--input_shapes")
      output += _get_message_unparsed(flag, "--dump_graphviz",
                                      "--dump_graphviz_dir")
    if output:
      raise ValueError(output)

  # Check that flags are valid.
  if flags.graph_def_file and (not flags.input_arrays or
                               not flags.output_arrays):
    raise ValueError("--input_arrays and --output_arrays are required with "
                     "--graph_def_file")

  if flags.input_shapes:
    if not flags.input_arrays:
      raise ValueError("--input_shapes must be used with --input_arrays")
    if flags.input_shapes.count(":") != flags.input_arrays.count(","):
      raise ValueError("--input_shapes and --input_arrays must have the same "
                       "number of items")

  if flags.std_dev_values or flags.mean_values:
    if bool(flags.std_dev_values) != bool(flags.mean_values):
      raise ValueError("--std_dev_values and --mean_values must be used "
                       "together")
    if flags.std_dev_values.count(",") != flags.mean_values.count(","):
      raise ValueError("--std_dev_values and --mean_values must have the same "
                       "number of items")

  if (flags.default_ranges_min is None) != (flags.default_ranges_max is None):
    raise ValueError("--default_ranges_min and --default_ranges_max must be "
                     "used together")

  if flags.dump_graphviz_video and not flags.dump_graphviz_dir:
    raise ValueError("--dump_graphviz_video must be used with "
                     "--dump_graphviz_dir")


def run_main(_):
  """Main in toco_convert.py."""
  if tf2.enabled():
    raise ValueError("tflite_convert is currently unsupported in 2.0. "
                     "Please use the Python API "
                     "tf.lite.TFLiteConverter.from_concrete_function().")

  parser = argparse.ArgumentParser(
      description=("Command line tool to run TensorFlow Lite Optimizing "
                   "Converter (TOCO)."))
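
  # Illustrative example invocation (file paths and tensor names are
  # placeholders); converting a frozen GraphDef requires naming the input and
  # output arrays explicitly:
  #
  #   tflite_convert \
  #     --graph_def_file=/tmp/frozen_graph.pb \
  #     --output_file=/tmp/model.tflite \
  #     --input_arrays=input \
  #     --output_arrays=softmax
  #
  # For a SavedModel, pass --saved_model_dir instead of --graph_def_file;
  # --input_arrays/--output_arrays are then optional (see _check_flags above).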

  # Output file flag.
  parser.add_argument(
      "--output_file",
      type=str,
      help="Full filepath of the output file.",
      required=True)

  # Input file flags.
  input_file_group = parser.add_mutually_exclusive_group(required=True)
  input_file_group.add_argument(
      "--graph_def_file",
      type=str,
      help="Full filepath of file containing frozen TensorFlow GraphDef.")
  input_file_group.add_argument(
      "--saved_model_dir",
      type=str,
      help="Full filepath of directory containing the SavedModel.")
  input_file_group.add_argument(
      "--keras_model_file",
      type=str,
      help="Full filepath of HDF5 file containing tf.Keras model.")

  # Model format flags.
  parser.add_argument(
      "--output_format",
      type=str.upper,
      choices=["TFLITE", "GRAPHVIZ_DOT"],
      help="Output file format.")
  parser.add_argument(
      "--inference_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help="Target data type of real-number arrays in the output file.")
  parser.add_argument(
      "--inference_input_type",
      type=str.upper,
      choices=["FLOAT", "QUANTIZED_UINT8"],
      help=("Target data type of real-number input arrays. Allows for a "
            "different type for input arrays in the case of quantization."))

  # Input and output arrays flags.
  parser.add_argument(
      "--input_arrays",
      type=str,
      help="Names of the input arrays, comma-separated.")
  parser.add_argument(
      "--input_shapes",
      type=str,
      help="Shapes corresponding to --input_arrays, colon-separated.")
  parser.add_argument(
      "--output_arrays",
      type=str,
      help="Names of the output arrays, comma-separated.")

  # SavedModel related flags.
  parser.add_argument(
      "--saved_model_tag_set",
      type=str,
      help=("Comma-separated set of tags identifying the MetaGraphDef within "
            "the SavedModel to analyze. All tags must be present. In order to "
            "pass in an empty tag set, pass in \"\". (default \"serve\")"))
  parser.add_argument(
      "--saved_model_signature_key",
      type=str,
      help=("Key identifying the SignatureDef containing inputs and outputs. "
            "(default DEFAULT_SERVING_SIGNATURE_DEF_KEY)"))

  # Quantization flags.
  parser.add_argument(
      "--std_dev_values",
      type=str,
      help=("Standard deviation of training data for each input tensor, "
            "comma-separated floats. Used for quantized input tensors. "
            "(default None)"))
  parser.add_argument(
      "--mean_values",
      type=str,
      help=("Mean of training data for each input tensor, comma-separated "
            "floats. Used for quantized input tensors. (default None)"))
  parser.add_argument(
      "--default_ranges_min",
      type=float,
      help=("Default value for min bound of min/max range values used for all "
            "arrays without a specified range. Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  parser.add_argument(
      "--default_ranges_max",
      type=float,
      help=("Default value for max bound of min/max range values used for all "
            "arrays without a specified range. Intended for experimenting with "
            "quantization via \"dummy quantization\". (default None)"))
  # quantize_weights is DEPRECATED; it shares a dest with
  # --post_training_quantize below and so acts as a hidden alias for it.
  parser.add_argument(
      "--quantize_weights",
      dest="post_training_quantize",
      action="store_true",
      help=argparse.SUPPRESS)
  parser.add_argument(
      "--post_training_quantize",
      dest="post_training_quantize",
      action="store_true",
      help=(
          "Boolean indicating whether to quantize the weights of the "
          "converted float model. Model size will be reduced and there will "
          "be latency improvements (at the cost of accuracy). (default False)"))

  # Graph manipulation flags.
  parser.add_argument(
      "--drop_control_dependency",
      action="store_true",
      help=("Boolean indicating whether to drop control dependencies silently. "
            "This is due to TensorFlow Lite not supporting control "
            "dependencies. (default True)"))
  parser.add_argument(
      "--reorder_across_fake_quant",
      action="store_true",
      help=("Boolean indicating whether to reorder FakeQuant nodes in "
            "unexpected locations. Used when the location of the FakeQuant "
            "nodes is preventing graph transformations necessary to convert "
            "the graph. Results in a graph that differs from the quantized "
            "training graph, potentially causing differing arithmetic "
            "behavior. (default False)"))
  # Usage for this flag is --change_concat_input_ranges=true or
  # --change_concat_input_ranges=false in order to make it clear what the flag
  # is set to. This keeps the usage consistent with other usages of the flag
  # where the default is different. The default value here is False.
  parser.add_argument(
      "--change_concat_input_ranges",
      type=str.upper,
      choices=["TRUE", "FALSE"],
      help=("Boolean to change behavior of min/max ranges for inputs and "
            "outputs of the concat operator for quantized models. Changes the "
            "ranges of concat operators to overlap when true. (default False)"))

  # Permitted ops flags.
  parser.add_argument(
      "--allow_custom_ops",
      action="store_true",
      help=("Boolean indicating whether to allow custom operations. When "
            "false, any unknown operation is an error. When true, custom ops "
            "are created for any op that is unknown. The developer will need "
            "to provide these to the TensorFlow Lite runtime with a custom "
            "resolver. (default False)"))
  parser.add_argument(
      "--target_ops",
      type=str,
      help=("Experimental flag, subject to change. Set of OpsSet options "
            "indicating which converter to use. Options: {0}. One or more "
            "options may be specified. (default set([OpsSet.TFLITE_BUILTINS]))"
            "".format(",".join(lite.OpsSet.get_options()))))

  # Logging flags.
  parser.add_argument(
      "--dump_graphviz_dir",
      type=str,
      help=("Full filepath of the folder to dump GraphViz .dot files of the "
            "graph at various stages of processing. Preferred over "
            "--output_format=GRAPHVIZ_DOT so that the output file itself "
            "remains in the requested format."))
  parser.add_argument(
      "--dump_graphviz_video",
      action="store_true",
      help=("Boolean indicating whether to dump the graph after every graph "
            "transformation. (default False)"))

  tflite_flags, unparsed = parser.parse_known_args(args=sys.argv[1:])
  try:
    _check_flags(tflite_flags, unparsed)
  except ValueError as e:
    parser.print_usage()
    file_name = os.path.basename(sys.argv[0])
    sys.stderr.write("{0}: error: {1}\n".format(file_name, str(e)))
    sys.exit(1)
  _convert_model(tflite_flags)


def main():
  app.run(main=run_main, argv=sys.argv[:1])


if __name__ == "__main__":
  main()