Home | History | Annotate | Download | only in saved_model
      1 ## Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 r"""Exports an example linear regression inference graph.
     16 
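Exports TensorFlow graphs based on the `SavedModel` format. By default three
variants are written: `/tmp/saved_model_half_plus_two` (binary protocol
buffer), `/tmp/saved_model_half_plus_two_pbtxt` (text protocol buffer) and
`/tmp/saved_model_half_plus_two_main_op` (built with a main op).

This graph calculates,

\\(
  y = a*x + b
\\)

and/or, independently,

\\(
  y2 = a*x + c
\\)

and

\\(
  y3 = a*x2 + c
\\)

where `a`, `b` and `c` are variables with `a=0.5`, `b=2` and `c=3`.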
     33 
     34 Output from this program is typically used to exercise SavedModel load and
     35 execution code.
     36 """
     37 
     38 from __future__ import absolute_import
     39 from __future__ import division
     40 from __future__ import print_function
     41 
     42 import argparse
     43 import os
     44 import sys
     45 
     46 import tensorflow as tf
     47 
     48 from tensorflow.python.lib.io import file_io
     49 
# Parsed command-line arguments (the argparse.Namespace returned by
# parse_known_args); assigned in the `if __name__ == "__main__":` block and
# read by main().
FLAGS = None
     51 
     52 
     53 def _write_assets(assets_directory, assets_filename):
     54   """Writes asset files to be used with SavedModel for half plus two.
     55 
     56   Args:
     57     assets_directory: The directory to which the assets should be written.
     58     assets_filename: Name of the file to which the asset contents should be
     59         written.
     60 
     61   Returns:
     62     The path to which the assets file was written.
     63   """
     64   if not file_io.file_exists(assets_directory):
     65     file_io.recursive_create_dir(assets_directory)
     66 
     67   path = os.path.join(
     68       tf.compat.as_bytes(assets_directory), tf.compat.as_bytes(assets_filename))
     69   file_io.write_string_to_file(path, "asset-file-contents")
     70   return path
     71 
     72 
     73 def _build_regression_signature(input_tensor, output_tensor):
     74   """Helper function for building a regression SignatureDef."""
     75   input_tensor_info = tf.saved_model.utils.build_tensor_info(input_tensor)
     76   signature_inputs = {
     77       tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor_info
     78   }
     79   output_tensor_info = tf.saved_model.utils.build_tensor_info(output_tensor)
     80   signature_outputs = {
     81       tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor_info
     82   }
     83   return tf.saved_model.signature_def_utils.build_signature_def(
     84       signature_inputs, signature_outputs,
     85       tf.saved_model.signature_constants.REGRESS_METHOD_NAME)
     86 
     87 
     88 # Possibly extend this to allow passing in 'classes', but for now this is
     89 # sufficient for testing purposes.
     90 def _build_classification_signature(input_tensor, scores_tensor):
     91   """Helper function for building a classification SignatureDef."""
     92   input_tensor_info = tf.saved_model.utils.build_tensor_info(input_tensor)
     93   signature_inputs = {
     94       tf.saved_model.signature_constants.CLASSIFY_INPUTS: input_tensor_info
     95   }
     96   output_tensor_info = tf.saved_model.utils.build_tensor_info(scores_tensor)
     97   signature_outputs = {
     98       tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES:
     99           output_tensor_info
    100   }
    101   return tf.saved_model.signature_def_utils.build_signature_def(
    102       signature_inputs, signature_outputs,
    103       tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME)
    104 
    105 
    106 def _generate_saved_model_for_half_plus_two(export_dir,
    107                                             as_text=False,
    108                                             use_main_op=False):
    109   """Generates SavedModel for half plus two.
    110 
    111   Args:
    112     export_dir: The directory to which the SavedModel should be written.
    113     as_text: Writes the SavedModel protocol buffer in text format to disk.
    114     use_main_op: Whether to supply a main op during SavedModel build time.
    115   """
    116   builder = tf.saved_model.builder.SavedModelBuilder(export_dir)
    117 
    118   with tf.Session(graph=tf.Graph()) as sess:
    119     # Set up the model parameters as variables to exercise variable loading
    120     # functionality upon restore.
    121     a = tf.Variable(0.5, name="a")
    122     b = tf.Variable(2.0, name="b")
    123     c = tf.Variable(3.0, name="c")
    124 
    125     # Create a placeholder for serialized tensorflow.Example messages to be fed.
    126     serialized_tf_example = tf.placeholder(tf.string, name="tf_example")
    127 
    128     # Parse the tensorflow.Example looking for a feature named "x" with a single
    129     # floating point value.
    130     feature_configs = {
    131         "x": tf.FixedLenFeature(
    132             [1], dtype=tf.float32),
    133         "x2": tf.FixedLenFeature(
    134             [1], dtype=tf.float32, default_value=[0.0])
    135     }
    136     tf_example = tf.parse_example(serialized_tf_example, feature_configs)
    137     # Use tf.identity() to assign name
    138     x = tf.identity(tf_example["x"], name="x")
    139     y = tf.add(tf.multiply(a, x), b, name="y")
    140     y2 = tf.add(tf.multiply(a, x), c, name="y2")
    141 
    142     x2 = tf.identity(tf_example["x2"], name="x2")
    143     y3 = tf.add(tf.multiply(a, x2), c, name="y3")
    144 
    145     # Create an assets file that can be saved and restored as part of the
    146     # SavedModel.
    147     original_assets_directory = "/tmp/original/export/assets"
    148     original_assets_filename = "foo.txt"
    149     original_assets_filepath = _write_assets(original_assets_directory,
    150                                              original_assets_filename)
    151 
    152     # Set up the assets collection.
    153     assets_filepath = tf.constant(original_assets_filepath)
    154     tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, assets_filepath)
    155     filename_tensor = tf.Variable(
    156         original_assets_filename,
    157         name="filename_tensor",
    158         trainable=False,
    159         collections=[])
    160     assign_filename_op = filename_tensor.assign(original_assets_filename)
    161 
    162     # Set up the signature for Predict with input and output tensor
    163     # specification.
    164     predict_input_tensor = tf.saved_model.utils.build_tensor_info(x)
    165     predict_signature_inputs = {"x": predict_input_tensor}
    166 
    167     predict_output_tensor = tf.saved_model.utils.build_tensor_info(y)
    168     predict_signature_outputs = {"y": predict_output_tensor}
    169     predict_signature_def = (
    170         tf.saved_model.signature_def_utils.build_signature_def(
    171             predict_signature_inputs, predict_signature_outputs,
    172             tf.saved_model.signature_constants.PREDICT_METHOD_NAME))
    173 
    174     signature_def_map = {
    175         "regress_x_to_y":
    176             _build_regression_signature(serialized_tf_example, y),
    177         "regress_x_to_y2":
    178             _build_regression_signature(serialized_tf_example, y2),
    179         "regress_x2_to_y3":
    180             _build_regression_signature(x2, y3),
    181         "classify_x_to_y":
    182             _build_classification_signature(serialized_tf_example, y),
    183         tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    184             predict_signature_def
    185     }
    186     # Initialize all variables and then save the SavedModel.
    187     sess.run(tf.global_variables_initializer())
    188     signature_def_map = {
    189         "regress_x_to_y":
    190             _build_regression_signature(serialized_tf_example, y),
    191         "regress_x_to_y2":
    192             _build_regression_signature(serialized_tf_example, y2),
    193         "regress_x2_to_y3":
    194             _build_regression_signature(x2, y3),
    195         "classify_x_to_y":
    196             _build_classification_signature(serialized_tf_example, y),
    197         "classify_x2_to_y3":
    198             _build_classification_signature(x2, y3),
    199         tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY:
    200             predict_signature_def
    201     }
    202     if use_main_op:
    203       builder.add_meta_graph_and_variables(
    204           sess, [tf.saved_model.tag_constants.SERVING],
    205           signature_def_map=signature_def_map,
    206           assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
    207           main_op=tf.group(tf.saved_model.main_op.main_op(),
    208                            assign_filename_op))
    209     else:
    210       builder.add_meta_graph_and_variables(
    211           sess, [tf.saved_model.tag_constants.SERVING],
    212           signature_def_map=signature_def_map,
    213           assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS),
    214           legacy_init_op=tf.group(assign_filename_op))
    215     builder.save(as_text)
    216 
    217 
    218 def main(_):
    219   _generate_saved_model_for_half_plus_two(FLAGS.output_dir)
    220   print("SavedModel generated at: %s" % FLAGS.output_dir)
    221 
    222   _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True)
    223   print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt)
    224 
    225   _generate_saved_model_for_half_plus_two(
    226       FLAGS.output_dir_main_op, use_main_op=True)
    227   print("SavedModel generated at: %s" % FLAGS.output_dir_main_op)
    228 
    229 
    230 if __name__ == "__main__":
    231   parser = argparse.ArgumentParser()
    232   parser.add_argument(
    233       "--output_dir",
    234       type=str,
    235       default="/tmp/saved_model_half_plus_two",
    236       help="Directory where to output SavedModel.")
    237   parser.add_argument(
    238       "--output_dir_pbtxt",
    239       type=str,
    240       default="/tmp/saved_model_half_plus_two_pbtxt",
    241       help="Directory where to output the text format of SavedModel.")
    242   parser.add_argument(
    243       "--output_dir_main_op",
    244       type=str,
    245       default="/tmp/saved_model_half_plus_two_main_op",
    246       help="Directory where to output the SavedModel with a main op.")
    247   FLAGS, unparsed = parser.parse_known_args()
    248   tf.app.run(main=main, argv=[sys.argv[0]] + unparsed)
    249