Home | History | Annotate | Download | only in tools
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Command-line interface to inspect and execute a graph in a SavedModel.
     16 
     17 For detailed usages and examples, please refer to:
     18 https://www.tensorflow.org/programmers_guide/saved_model_cli
     19 
     20 """
     21 
     22 from __future__ import absolute_import
     23 from __future__ import division
     24 from __future__ import print_function
     25 
     26 import argparse
     27 import os
     28 import re
     29 import sys
     30 import warnings
     31 
     32 import numpy as np
     33 
     34 from six import integer_types
     35 from tensorflow.contrib.saved_model.python.saved_model import reader
     36 from tensorflow.contrib.saved_model.python.saved_model import signature_def_utils
     37 from tensorflow.core.example import example_pb2
     38 from tensorflow.core.framework import types_pb2
     39 from tensorflow.python.client import session
     40 from tensorflow.python.debug.wrappers import local_cli_wrapper
     41 from tensorflow.python.framework import ops as ops_lib
     42 from tensorflow.python.platform import app  # pylint: disable=unused-import
     43 from tensorflow.python.saved_model import loader
     44 from tensorflow.python.tools import saved_model_utils
     45 
     46 
     47 def _show_tag_sets(saved_model_dir):
     48   """Prints the tag-sets stored in SavedModel directory.
     49 
     50   Prints all the tag-sets for MetaGraphs stored in SavedModel directory.
     51 
     52   Args:
     53     saved_model_dir: Directory containing the SavedModel to inspect.
     54   """
     55   tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
     56   print('The given SavedModel contains the following tag-sets:')
     57   for tag_set in sorted(tag_sets):
     58     print(', '.join(sorted(tag_set)))
     59 
     60 
     61 def _show_signature_def_map_keys(saved_model_dir, tag_set):
     62   """Prints the keys for each SignatureDef in the SignatureDef map.
     63 
     64   Prints the list of SignatureDef keys from the SignatureDef map specified by
     65   the given tag-set and SavedModel directory.
     66 
     67   Args:
     68     saved_model_dir: Directory containing the SavedModel to inspect.
     69     tag_set: Group of tag(s) of the MetaGraphDef to get SignatureDef map from,
     70         in string format, separated by ','. For tag-set contains multiple tags,
     71         all tags must be passed in.
     72   """
     73   signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
     74   print('The given SavedModel MetaGraphDef contains SignatureDefs with the '
     75         'following keys:')
     76   for signature_def_key in sorted(signature_def_map.keys()):
     77     print('SignatureDef key: \"%s\"' % signature_def_key)
     78 
     79 
     80 def _get_inputs_tensor_info_from_meta_graph_def(meta_graph_def,
     81                                                 signature_def_key):
     82   """Gets TensorInfo for all inputs of the SignatureDef.
     83 
     84   Returns a dictionary that maps each input key to its TensorInfo for the given
     85   signature_def_key in the meta_graph_def
     86 
     87   Args:
     88     meta_graph_def: MetaGraphDef protocol buffer with the SignatureDef map to
     89         look up SignatureDef key.
     90     signature_def_key: A SignatureDef key string.
     91 
     92   Returns:
     93     A dictionary that maps input tensor keys to TensorInfos.
     94   """
     95   return signature_def_utils.get_signature_def_by_key(meta_graph_def,
     96                                                       signature_def_key).inputs
     97 
     98 
     99 def _get_outputs_tensor_info_from_meta_graph_def(meta_graph_def,
    100                                                  signature_def_key):
    101   """Gets TensorInfos for all outputs of the SignatureDef.
    102 
    103   Returns a dictionary that maps each output key to its TensorInfo for the given
    104   signature_def_key in the meta_graph_def.
    105 
    106   Args:
    107     meta_graph_def: MetaGraphDef protocol buffer with the SignatureDefmap to
    108     look up signature_def_key.
    109     signature_def_key: A SignatureDef key string.
    110 
    111   Returns:
    112     A dictionary that maps output tensor keys to TensorInfos.
    113   """
    114   return signature_def_utils.get_signature_def_by_key(meta_graph_def,
    115                                                       signature_def_key).outputs
    116 
    117 
    118 def _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key):
    119   """Prints input and output TensorInfos.
    120 
    121   Prints the details of input and output TensorInfos for the SignatureDef mapped
    122   by the given signature_def_key.
    123 
    124   Args:
    125     saved_model_dir: Directory containing the SavedModel to inspect.
    126     tag_set: Group of tag(s) of the MetaGraphDef, in string format, separated by
    127         ','. For tag-set contains multiple tags, all tags must be passed in.
    128     signature_def_key: A SignatureDef key string.
    129   """
    130   meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
    131                                                         tag_set)
    132   inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
    133       meta_graph_def, signature_def_key)
    134   outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
    135       meta_graph_def, signature_def_key)
    136 
    137   print('The given SavedModel SignatureDef contains the following input(s):')
    138   for input_key, input_tensor in sorted(inputs_tensor_info.items()):
    139     print('inputs[\'%s\'] tensor_info:' % input_key)
    140     _print_tensor_info(input_tensor)
    141 
    142   print('The given SavedModel SignatureDef contains the following output(s):')
    143   for output_key, output_tensor in sorted(outputs_tensor_info.items()):
    144     print('outputs[\'%s\'] tensor_info:' % output_key)
    145     _print_tensor_info(output_tensor)
    146 
    147   print('Method name is: %s' %
    148         meta_graph_def.signature_def[signature_def_key].method_name)
    149 
    150 
    151 def _print_tensor_info(tensor_info):
    152   """Prints details of the given tensor_info.
    153 
    154   Args:
    155     tensor_info: TensorInfo object to be printed.
    156   """
    157   print('    dtype: ' +
    158         {value: key
    159          for (key, value) in types_pb2.DataType.items()}[tensor_info.dtype])
    160   # Display shape as tuple.
    161   if tensor_info.tensor_shape.unknown_rank:
    162     shape = 'unknown_rank'
    163   else:
    164     dims = [str(dim.size) for dim in tensor_info.tensor_shape.dim]
    165     shape = ', '.join(dims)
    166     shape = '(' + shape + ')'
    167   print('    shape: ' + shape)
    168   print('    name: ' + tensor_info.name)
    169 
    170 
    171 def _show_all(saved_model_dir):
    172   """Prints tag-set, SignatureDef and Inputs/Outputs information in SavedModel.
    173 
    174   Prints all tag-set, SignatureDef and Inputs/Outputs information stored in
    175   SavedModel directory.
    176 
    177   Args:
    178     saved_model_dir: Directory containing the SavedModel to inspect.
    179   """
    180   tag_sets = reader.get_saved_model_tag_sets(saved_model_dir)
    181   for tag_set in sorted(tag_sets):
    182     tag_set = ', '.join(tag_set)
    183     print('\nMetaGraphDef with tag-set: \'' + tag_set +
    184           '\' contains the following SignatureDefs:')
    185 
    186     signature_def_map = get_signature_def_map(saved_model_dir, tag_set)
    187     for signature_def_key in sorted(signature_def_map.keys()):
    188       print('\nsignature_def[\'' + signature_def_key + '\']:')
    189       _show_inputs_outputs(saved_model_dir, tag_set, signature_def_key)
    190 
    191 
    192 def get_meta_graph_def(saved_model_dir, tag_set):
    193   """DEPRECATED: Use saved_model_utils.get_meta_graph_def instead.
    194 
    195   Gets MetaGraphDef from SavedModel. Returns the MetaGraphDef for the given
    196   tag-set and SavedModel directory.
    197 
    198   Args:
    199     saved_model_dir: Directory containing the SavedModel to inspect or execute.
    200     tag_set: Group of tag(s) of the MetaGraphDef to load, in string format,
    201         separated by ','. For tag-set contains multiple tags, all tags must be
    202         passed in.
    203 
    204   Raises:
    205     RuntimeError: An error when the given tag-set does not exist in the
    206         SavedModel.
    207 
    208   Returns:
    209     A MetaGraphDef corresponding to the tag-set.
    210   """
    211   return saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)
    212 
    213 
    214 def get_signature_def_map(saved_model_dir, tag_set):
    215   """Gets SignatureDef map from a MetaGraphDef in a SavedModel.
    216 
    217   Returns the SignatureDef map for the given tag-set in the SavedModel
    218   directory.
    219 
    220   Args:
    221     saved_model_dir: Directory containing the SavedModel to inspect or execute.
    222     tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
    223         string format, separated by ','. For tag-set contains multiple tags, all
    224         tags must be passed in.
    225 
    226   Returns:
    227     A SignatureDef map that maps from string keys to SignatureDefs.
    228   """
    229   meta_graph = saved_model_utils.get_meta_graph_def(saved_model_dir, tag_set)
    230   return meta_graph.signature_def
    231 
    232 
    233 def run_saved_model_with_feed_dict(saved_model_dir, tag_set, signature_def_key,
    234                                    input_tensor_key_feed_dict, outdir,
    235                                    overwrite_flag, tf_debug=False):
    236   """Runs SavedModel and fetch all outputs.
    237 
    238   Runs the input dictionary through the MetaGraphDef within a SavedModel
    239   specified by the given tag_set and SignatureDef. Also save the outputs to file
    240   if outdir is not None.
    241 
    242   Args:
    243     saved_model_dir: Directory containing the SavedModel to execute.
    244     tag_set: Group of tag(s) of the MetaGraphDef with the SignatureDef map, in
    245         string format, separated by ','. For tag-set contains multiple tags, all
    246         tags must be passed in.
    247     signature_def_key: A SignatureDef key string.
    248     input_tensor_key_feed_dict: A dictionary maps input keys to numpy ndarrays.
    249     outdir: A directory to save the outputs to. If the directory doesn't exist,
    250         it will be created.
    251     overwrite_flag: A boolean flag to allow overwrite output file if file with
    252         the same name exists.
    253     tf_debug: A boolean flag to use TensorFlow Debugger (TFDBG) to observe the
    254         intermediate Tensor values and runtime GraphDefs while running the
    255         SavedModel.
    256 
    257   Raises:
    258     ValueError: When any of the input tensor keys is not valid.
    259     RuntimeError: An error when output file already exists and overwrite is not
    260     enabled.
    261   """
    262   # Get a list of output tensor names.
    263   meta_graph_def = saved_model_utils.get_meta_graph_def(saved_model_dir,
    264                                                         tag_set)
    265 
    266   # Re-create feed_dict based on input tensor name instead of key as session.run
    267   # uses tensor name.
    268   inputs_tensor_info = _get_inputs_tensor_info_from_meta_graph_def(
    269       meta_graph_def, signature_def_key)
    270 
    271   # Check if input tensor keys are valid.
    272   for input_key_name in input_tensor_key_feed_dict.keys():
    273     if input_key_name not in inputs_tensor_info.keys():
    274       raise ValueError(
    275           '"%s" is not a valid input key. Please choose from %s, or use '
    276           '--show option.' %
    277           (input_key_name, '"' + '", "'.join(inputs_tensor_info.keys()) + '"'))
    278 
    279   inputs_feed_dict = {
    280       inputs_tensor_info[key].name: tensor
    281       for key, tensor in input_tensor_key_feed_dict.items()
    282   }
    283   # Get outputs
    284   outputs_tensor_info = _get_outputs_tensor_info_from_meta_graph_def(
    285       meta_graph_def, signature_def_key)
    286   # Sort to preserve order because we need to go from value to key later.
    287   output_tensor_keys_sorted = sorted(outputs_tensor_info.keys())
    288   output_tensor_names_sorted = [
    289       outputs_tensor_info[tensor_key].name
    290       for tensor_key in output_tensor_keys_sorted
    291   ]
    292 
    293   with session.Session(graph=ops_lib.Graph()) as sess:
    294     loader.load(sess, tag_set.split(','), saved_model_dir)
    295 
    296     if tf_debug:
    297       sess = local_cli_wrapper.LocalCLIDebugWrapperSession(sess)
    298 
    299     outputs = sess.run(output_tensor_names_sorted, feed_dict=inputs_feed_dict)
    300 
    301     for i, output in enumerate(outputs):
    302       output_tensor_key = output_tensor_keys_sorted[i]
    303       print('Result for output key %s:\n%s' % (output_tensor_key, output))
    304 
    305       # Only save if outdir is specified.
    306       if outdir:
    307         # Create directory if outdir does not exist
    308         if not os.path.isdir(outdir):
    309           os.makedirs(outdir)
    310         output_full_path = os.path.join(outdir, output_tensor_key + '.npy')
    311 
    312         # If overwrite not enabled and file already exist, error out
    313         if not overwrite_flag and os.path.exists(output_full_path):
    314           raise RuntimeError(
    315               'Output file %s already exists. Add \"--overwrite\" to overwrite'
    316               ' the existing output files.' % output_full_path)
    317 
    318         np.save(output_full_path, output)
    319         print('Output %s is saved to %s' % (output_tensor_key,
    320                                             output_full_path))
    321 
    322 
    323 def preprocess_inputs_arg_string(inputs_str):
    324   """Parses input arg into dictionary that maps input to file/variable tuple.
    325 
    326   Parses input string in the format of, for example,
    327   "input1=filename1[variable_name1],input2=filename2" into a
    328   dictionary looks like
    329   {'input_key1': (filename1, variable_name1),
    330    'input_key2': (file2, None)}
    331   , which maps input keys to a tuple of file name and variable name(None if
    332   empty).
    333 
    334   Args:
    335     inputs_str: A string that specified where to load inputs. Inputs are
    336     separated by semicolons.
    337         * For each input key:
    338             '<input_key>=<filename>' or
    339             '<input_key>=<filename>[<variable_name>]'
    340         * The optional 'variable_name' key will be set to None if not specified.
    341 
    342   Returns:
    343     A dictionary that maps input keys to a tuple of file name and variable name.
    344 
    345   Raises:
    346     RuntimeError: An error when the given input string is in a bad format.
    347   """
    348   input_dict = {}
    349   inputs_raw = inputs_str.split(';')
    350   for input_raw in filter(bool, inputs_raw):  # skip empty strings
    351     # Format of input=filename[variable_name]'
    352     match = re.match(r'([^=]+)=([^\[\]]+)\[([^\[\]]+)\]$', input_raw)
    353 
    354     if match:
    355       input_dict[match.group(1)] = match.group(2), match.group(3)
    356     else:
    357       # Format of input=filename'
    358       match = re.match(r'([^=]+)=([^\[\]]+)$', input_raw)
    359       if match:
    360         input_dict[match.group(1)] = match.group(2), None
    361       else:
    362         raise RuntimeError(
    363             '--inputs "%s" format is incorrect. Please follow'
    364             '"<input_key>=<filename>", or'
    365             '"<input_key>=<filename>[<variable_name>]"' % input_raw)
    366 
    367   return input_dict
    368 
    369 
    370 def preprocess_input_exprs_arg_string(input_exprs_str):
    371   """Parses input arg into dictionary that maps input key to python expression.
    372 
    373   Parses input string in the format of 'input_key=<python expression>' into a
    374   dictionary that maps each input_key to its python expression.
    375 
    376   Args:
    377     input_exprs_str: A string that specifies python expression for input keys.
    378     Each input is separated by semicolon. For each input key:
    379         'input_key=<python expression>'
    380 
    381   Returns:
    382     A dictionary that maps input keys to their values.
    383 
    384   Raises:
    385     RuntimeError: An error when the given input string is in a bad format.
    386   """
    387   input_dict = {}
    388 
    389   for input_raw in filter(bool, input_exprs_str.split(';')):
    390     if '=' not in input_exprs_str:
    391       raise RuntimeError('--input_exprs "%s" format is incorrect. Please follow'
    392                          '"<input_key>=<python expression>"' % input_exprs_str)
    393     input_key, expr = input_raw.split('=', 1)
    394     # ast.literal_eval does not work with numpy expressions
    395     input_dict[input_key] = eval(expr)  # pylint: disable=eval-used
    396   return input_dict
    397 
    398 
    399 def preprocess_input_examples_arg_string(input_examples_str):
    400   """Parses input into dict that maps input keys to lists of tf.Example.
    401 
    402   Parses input string in the format of 'input_key1=[{feature_name:
    403   feature_list}];input_key2=[{feature_name:feature_list}];' into a dictionary
    404   that maps each input_key to its list of serialized tf.Example.
    405 
    406   Args:
    407     input_examples_str: A string that specifies a list of dictionaries of
    408     feature_names and their feature_lists for each input.
    409     Each input is separated by semicolon. For each input key:
    410       'input=[{feature_name1: feature_list1, feature_name2:feature_list2}]'
    411       items in feature_list can be the type of float, int, long or str.
    412 
    413   Returns:
    414     A dictionary that maps input keys to lists of serialized tf.Example.
    415 
    416   Raises:
    417     ValueError: An error when the given tf.Example is not a list.
    418   """
    419   input_dict = preprocess_input_exprs_arg_string(input_examples_str)
    420   for input_key, example_list in input_dict.items():
    421     if not isinstance(example_list, list):
    422       raise ValueError(
    423           'tf.Example input must be a list of dictionaries, but "%s" is %s' %
    424           (example_list, type(example_list)))
    425     input_dict[input_key] = [
    426         _create_example_string(example) for example in example_list
    427     ]
    428   return input_dict
    429 
    430 
    431 def _create_example_string(example_dict):
    432   """Create a serialized tf.example from feature dictionary."""
    433   example = example_pb2.Example()
    434   for feature_name, feature_list in example_dict.items():
    435     if not isinstance(feature_list, list):
    436       raise ValueError('feature value must be a list, but %s: "%s" is %s' %
    437                        (feature_name, feature_list, type(feature_list)))
    438     if isinstance(feature_list[0], float):
    439       example.features.feature[feature_name].float_list.value.extend(
    440           feature_list)
    441     elif isinstance(feature_list[0], str):
    442       example.features.feature[feature_name].bytes_list.value.extend(
    443           feature_list)
    444     elif isinstance(feature_list[0], integer_types):
    445       example.features.feature[feature_name].int64_list.value.extend(
    446           feature_list)
    447     else:
    448       raise ValueError(
    449           'Type %s for value %s is not supported for tf.train.Feature.' %
    450           (type(feature_list[0]), feature_list[0]))
    451   return example.SerializeToString()
    452 
    453 
    454 def load_inputs_from_input_arg_string(inputs_str, input_exprs_str,
    455                                       input_examples_str):
    456   """Parses input arg strings and create inputs feed_dict.
    457 
    458   Parses '--inputs' string for inputs to be loaded from file, and parses
    459   '--input_exprs' string for inputs to be evaluated from python expression.
    460   '--input_examples' string for inputs to be created from tf.example feature
    461   dictionary list.
    462 
    463   Args:
    464     inputs_str: A string that specified where to load inputs. Each input is
    465         separated by semicolon.
    466         * For each input key:
    467             '<input_key>=<filename>' or
    468             '<input_key>=<filename>[<variable_name>]'
    469         * The optional 'variable_name' key will be set to None if not specified.
    470         * File specified by 'filename' will be loaded using numpy.load. Inputs
    471             can be loaded from only .npy, .npz or pickle files.
    472         * The "[variable_name]" key is optional depending on the input file type
    473             as descripted in more details below.
    474         When loading from a npy file, which always contains a numpy ndarray, the
    475         content will be directly assigned to the specified input tensor. If a
    476         variable_name is specified, it will be ignored and a warning will be
    477         issued.
    478         When loading from a npz zip file, user can specify which variable within
    479         the zip file to load for the input tensor inside the square brackets. If
    480         nothing is specified, this function will check that only one file is
    481         included in the zip and load it for the specified input tensor.
    482         When loading from a pickle file, if no variable_name is specified in the
    483         square brackets, whatever that is inside the pickle file will be passed
    484         to the specified input tensor, else SavedModel CLI will assume a
    485         dictionary is stored in the pickle file and the value corresponding to
    486         the variable_name will be used.
    487     input_exprs_str: A string that specifies python expressions for inputs.
    488         * In the format of: '<input_key>=<python expression>'.
    489         * numpy module is available as np.
    490     input_examples_str: A string that specifies tf.Example with dictionary.
    491         * In the format of: '<input_key>=<[{feature:value list}]>'
    492 
    493   Returns:
    494     A dictionary that maps input tensor keys to numpy ndarrays.
    495 
    496   Raises:
    497     RuntimeError: An error when a key is specified, but the input file contains
    498         multiple numpy ndarrays, none of which matches the given key.
    499     RuntimeError: An error when no key is specified, but the input file contains
    500         more than one numpy ndarrays.
    501   """
    502   tensor_key_feed_dict = {}
    503 
    504   inputs = preprocess_inputs_arg_string(inputs_str)
    505   input_exprs = preprocess_input_exprs_arg_string(input_exprs_str)
    506   input_examples = preprocess_input_examples_arg_string(input_examples_str)
    507 
    508   for input_tensor_key, (filename, variable_name) in inputs.items():
    509     data = np.load(filename)
    510 
    511     # When a variable_name key is specified for the input file
    512     if variable_name:
    513       # if file contains a single ndarray, ignore the input name
    514       if isinstance(data, np.ndarray):
    515         warnings.warn(
    516             'Input file %s contains a single ndarray. Name key \"%s\" ignored.'
    517             % (filename, variable_name))
    518         tensor_key_feed_dict[input_tensor_key] = data
    519       else:
    520         if variable_name in data:
    521           tensor_key_feed_dict[input_tensor_key] = data[variable_name]
    522         else:
    523           raise RuntimeError(
    524               'Input file %s does not contain variable with name \"%s\".' %
    525               (filename, variable_name))
    526     # When no key is specified for the input file.
    527     else:
    528       # Check if npz file only contains a single numpy ndarray.
    529       if isinstance(data, np.lib.npyio.NpzFile):
    530         variable_name_list = data.files
    531         if len(variable_name_list) != 1:
    532           raise RuntimeError(
    533               'Input file %s contains more than one ndarrays. Please specify '
    534               'the name of ndarray to use.' % filename)
    535         tensor_key_feed_dict[input_tensor_key] = data[variable_name_list[0]]
    536       else:
    537         tensor_key_feed_dict[input_tensor_key] = data
    538 
    539   # When input is a python expression:
    540   for input_tensor_key, py_expr_evaluated in input_exprs.items():
    541     if input_tensor_key in tensor_key_feed_dict:
    542       warnings.warn(
    543           'input_key %s has been specified with both --inputs and --input_exprs'
    544           ' options. Value in --input_exprs will be used.' % input_tensor_key)
    545     tensor_key_feed_dict[input_tensor_key] = py_expr_evaluated
    546 
    547   # When input is a tf.Example:
    548   for input_tensor_key, example in input_examples.items():
    549     if input_tensor_key in tensor_key_feed_dict:
    550       warnings.warn(
    551           'input_key %s has been specified in multiple options. Value in '
    552           '--input_examples will be used.' % input_tensor_key)
    553     tensor_key_feed_dict[input_tensor_key] = example
    554   return tensor_key_feed_dict
    555 
    556 
    557 def show(args):
    558   """Function triggered by show command.
    559 
    560   Args:
    561     args: A namespace parsed from command line.
    562   """
    563   # If all tag is specified, display all information.
    564   if args.all:
    565     _show_all(args.dir)
    566   else:
    567     # If no tag is specified, display all tag_set, if no signaure_def key is
    568     # specified, display all SignatureDef keys, else show input output tensor
    569     # information corresponding to the given SignatureDef key
    570     if args.tag_set is None:
    571       _show_tag_sets(args.dir)
    572     else:
    573       if args.signature_def is None:
    574         _show_signature_def_map_keys(args.dir, args.tag_set)
    575       else:
    576         _show_inputs_outputs(args.dir, args.tag_set, args.signature_def)
    577 
    578 
    579 def run(args):
    580   """Function triggered by run command.
    581 
    582   Args:
    583     args: A namespace parsed from command line.
    584 
    585   Raises:
    586     AttributeError: An error when neither --inputs nor --input_exprs is passed
    587     to run command.
    588   """
    589   if not args.inputs and not args.input_exprs and not args.input_examples:
    590     raise AttributeError(
    591         'At least one of --inputs, --input_exprs or --input_examples must be '
    592         'required')
    593   tensor_key_feed_dict = load_inputs_from_input_arg_string(
    594       args.inputs, args.input_exprs, args.input_examples)
    595   run_saved_model_with_feed_dict(args.dir, args.tag_set, args.signature_def,
    596                                  tensor_key_feed_dict, args.outdir,
    597                                  args.overwrite, tf_debug=args.tf_debug)
    598 
    599 
    600 def create_parser():
    601   """Creates a parser that parse the command line arguments.
    602 
    603   Returns:
    604     A namespace parsed from command line arguments.
    605   """
    606   parser = argparse.ArgumentParser(
    607       description='saved_model_cli: Command-line interface for SavedModel')
    608   parser.add_argument('-v', '--version', action='version', version='0.1.0')
    609 
    610   subparsers = parser.add_subparsers(
    611       title='commands', description='valid commands', help='additional help')
    612 
    613   # show command
    614   show_msg = (
    615       'Usage examples:\n'
    616       'To show all tag-sets in a SavedModel:\n'
    617       '$saved_model_cli show --dir /tmp/saved_model\n'
    618       'To show all available SignatureDef keys in a '
    619       'MetaGraphDef specified by its tag-set:\n'
    620       '$saved_model_cli show --dir /tmp/saved_model --tag_set serve\n'
    621       'For a MetaGraphDef with multiple tags in the tag-set, all tags must be '
    622       'passed in, separated by \';\':\n'
    623       '$saved_model_cli show --dir /tmp/saved_model --tag_set serve,gpu\n\n'
    624       'To show all inputs and outputs TensorInfo for a specific'
    625       ' SignatureDef specified by the SignatureDef key in a'
    626       ' MetaGraph.\n'
    627       '$saved_model_cli show --dir /tmp/saved_model --tag_set serve '
    628       '--signature_def serving_default\n\n'
    629       'To show all available information in the SavedModel\n:'
    630       '$saved_model_cli show --dir /tmp/saved_model --all')
    631   parser_show = subparsers.add_parser(
    632       'show',
    633       description=show_msg,
    634       formatter_class=argparse.RawTextHelpFormatter)
    635   parser_show.add_argument(
    636       '--dir',
    637       type=str,
    638       required=True,
    639       help='directory containing the SavedModel to inspect')
    640   parser_show.add_argument(
    641       '--all',
    642       action='store_true',
    643       help='if set, will output all information in given SavedModel')
    644   parser_show.add_argument(
    645       '--tag_set',
    646       type=str,
    647       default=None,
    648       help='tag-set of graph in SavedModel to show, separated by \',\'')
    649   parser_show.add_argument(
    650       '--signature_def',
    651       type=str,
    652       default=None,
    653       metavar='SIGNATURE_DEF_KEY',
    654       help='key of SignatureDef to display input(s) and output(s) for')
    655   parser_show.set_defaults(func=show)
    656 
    657   # run command
    658   run_msg = ('Usage example:\n'
    659              'To run input tensors from files through a MetaGraphDef and save'
    660              ' the output tensors to files:\n'
    661              '$saved_model_cli show --dir /tmp/saved_model --tag_set serve '
    662              '--signature_def serving_default '
    663              '--inputs input1_key=/tmp/124.npz[x],input2_key=/tmp/123.npy '
    664              '--input_exprs \'input3_key=np.ones(2)\' --input_examples '
    665              '\'input4_key=[{"id":[26],"weights":[0.5, 0.5]}]\' '
    666              '--outdir=/out\n\n'
    667              'For more information about input file format, please see:\n'
    668              'https://www.tensorflow.org/programmers_guide/saved_model_cli\n')
    669   parser_run = subparsers.add_parser(
    670       'run', description=run_msg, formatter_class=argparse.RawTextHelpFormatter)
    671   parser_run.add_argument(
    672       '--dir',
    673       type=str,
    674       required=True,
    675       help='directory containing the SavedModel to execute')
    676   parser_run.add_argument(
    677       '--tag_set',
    678       type=str,
    679       required=True,
    680       help='tag-set of graph in SavedModel to load, separated by \',\'')
    681   parser_run.add_argument(
    682       '--signature_def',
    683       type=str,
    684       required=True,
    685       metavar='SIGNATURE_DEF_KEY',
    686       help='key of SignatureDef to run')
    687   msg = ('Loading inputs from files, in the format of \'<input_key>=<filename>,'
    688          ' or \'<input_key>=<filename>[<variable_name>]\', separated by \';\'.'
    689          ' The file format can only be from .npy, .npz or pickle.')
    690   parser_run.add_argument('--inputs', type=str, default='', help=msg)
    691   msg = ('Specifying inputs by python expressions, in the format of'
    692          ' "<input_key>=\'<python expression>\'", separated by \';\'. '
    693          'numpy module is available as \'np\'. '
    694          'Will override duplicate input keys from --inputs option.')
    695   parser_run.add_argument('--input_exprs', type=str, default='', help=msg)
    696   msg = (
    697       'Specifying tf.Example inputs as list of dictionaries. For example: '
    698       '<input_key>=[{feature0:value_list,feature1:value_list}]. Use ";" to '
    699       'separate input keys. Will override duplicate input keys from --inputs '
    700       'and --input_exprs option.')
    701   parser_run.add_argument('--input_examples', type=str, default='', help=msg)
    702   parser_run.add_argument(
    703       '--outdir',
    704       type=str,
    705       default=None,
    706       help='if specified, output tensor(s) will be saved to given directory')
    707   parser_run.add_argument(
    708       '--overwrite',
    709       action='store_true',
    710       help='if set, output file will be overwritten if it already exists.')
    711   parser_run.add_argument(
    712       '--tf_debug',
    713       action='store_true',
    714       help='if set, will use TensorFlow Debugger (tfdbg) to watch the '
    715            'intermediate Tensors and runtime GraphDefs while running the '
    716            'SavedModel.')
    717   parser_run.set_defaults(func=run)
    718 
    719   return parser
    720 
    721 
    722 def main():
    723   parser = create_parser()
    724   args = parser.parse_args()
    725   if not hasattr(args, 'func'):
    726     parser.error('too few arguments')
    727   args.func(args)
    728 
    729 
# Script entry point: run the CLI only when this file is executed directly.
# main() has no return statement, so sys.exit receives None on normal
# completion.
if __name__ == '__main__':
  sys.exit(main())
    732