1 ## Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 r"""Exports an example linear regression inference graph. 16 17 Exports a TensorFlow graph to `/tmp/saved_model/half_plus_two/` based on the 18 `SavedModel` format. 19 20 This graph calculates, 21 22 \\( 23 y = a*x + b 24 \\) 25 26 and/or, independently, 27 28 \\( 29 y2 = a*x2 + c 30 \\) 31 32 where `a`, `b` and `c` are variables with `a=0.5` and `b=2` and `c=3`. 33 34 Output from this program is typically used to exercise SavedModel load and 35 execution code. 36 """ 37 38 from __future__ import absolute_import 39 from __future__ import division 40 from __future__ import print_function 41 42 import argparse 43 import os 44 import sys 45 46 import tensorflow as tf 47 48 from tensorflow.python.lib.io import file_io 49 50 FLAGS = None 51 52 53 def _write_assets(assets_directory, assets_filename): 54 """Writes asset files to be used with SavedModel for half plus two. 55 56 Args: 57 assets_directory: The directory to which the assets should be written. 58 assets_filename: Name of the file to which the asset contents should be 59 written. 60 61 Returns: 62 The path to which the assets file was written. 63 """ 64 if not file_io.file_exists(assets_directory): 65 file_io.recursive_create_dir(assets_directory) 66 67 path = os.path.join( 68 tf.compat.as_bytes(assets_directory), tf.compat.as_bytes(assets_filename)) 69 file_io.write_string_to_file(path, "asset-file-contents") 70 return path 71 72 73 def _build_regression_signature(input_tensor, output_tensor): 74 """Helper function for building a regression SignatureDef.""" 75 input_tensor_info = tf.saved_model.utils.build_tensor_info(input_tensor) 76 signature_inputs = { 77 tf.saved_model.signature_constants.REGRESS_INPUTS: input_tensor_info 78 } 79 output_tensor_info = tf.saved_model.utils.build_tensor_info(output_tensor) 80 signature_outputs = { 81 tf.saved_model.signature_constants.REGRESS_OUTPUTS: output_tensor_info 82 } 83 return tf.saved_model.signature_def_utils.build_signature_def( 84 signature_inputs, signature_outputs, 85 tf.saved_model.signature_constants.REGRESS_METHOD_NAME) 86 87 88 # Possibly extend this to allow passing in 'classes', but for now this is 89 # sufficient for testing purposes. 90 def _build_classification_signature(input_tensor, scores_tensor): 91 """Helper function for building a classification SignatureDef.""" 92 input_tensor_info = tf.saved_model.utils.build_tensor_info(input_tensor) 93 signature_inputs = { 94 tf.saved_model.signature_constants.CLASSIFY_INPUTS: input_tensor_info 95 } 96 output_tensor_info = tf.saved_model.utils.build_tensor_info(scores_tensor) 97 signature_outputs = { 98 tf.saved_model.signature_constants.CLASSIFY_OUTPUT_SCORES: 99 output_tensor_info 100 } 101 return tf.saved_model.signature_def_utils.build_signature_def( 102 signature_inputs, signature_outputs, 103 tf.saved_model.signature_constants.CLASSIFY_METHOD_NAME) 104 105 106 def _generate_saved_model_for_half_plus_two(export_dir, 107 as_text=False, 108 use_main_op=False): 109 """Generates SavedModel for half plus two. 110 111 Args: 112 export_dir: The directory to which the SavedModel should be written. 113 as_text: Writes the SavedModel protocol buffer in text format to disk. 114 use_main_op: Whether to supply a main op during SavedModel build time. 115 """ 116 builder = tf.saved_model.builder.SavedModelBuilder(export_dir) 117 118 with tf.Session(graph=tf.Graph()) as sess: 119 # Set up the model parameters as variables to exercise variable loading 120 # functionality upon restore. 121 a = tf.Variable(0.5, name="a") 122 b = tf.Variable(2.0, name="b") 123 c = tf.Variable(3.0, name="c") 124 125 # Create a placeholder for serialized tensorflow.Example messages to be fed. 126 serialized_tf_example = tf.placeholder(tf.string, name="tf_example") 127 128 # Parse the tensorflow.Example looking for a feature named "x" with a single 129 # floating point value. 130 feature_configs = { 131 "x": tf.FixedLenFeature( 132 [1], dtype=tf.float32), 133 "x2": tf.FixedLenFeature( 134 [1], dtype=tf.float32, default_value=[0.0]) 135 } 136 tf_example = tf.parse_example(serialized_tf_example, feature_configs) 137 # Use tf.identity() to assign name 138 x = tf.identity(tf_example["x"], name="x") 139 y = tf.add(tf.multiply(a, x), b, name="y") 140 y2 = tf.add(tf.multiply(a, x), c, name="y2") 141 142 x2 = tf.identity(tf_example["x2"], name="x2") 143 y3 = tf.add(tf.multiply(a, x2), c, name="y3") 144 145 # Create an assets file that can be saved and restored as part of the 146 # SavedModel. 147 original_assets_directory = "/tmp/original/export/assets" 148 original_assets_filename = "foo.txt" 149 original_assets_filepath = _write_assets(original_assets_directory, 150 original_assets_filename) 151 152 # Set up the assets collection. 153 assets_filepath = tf.constant(original_assets_filepath) 154 tf.add_to_collection(tf.GraphKeys.ASSET_FILEPATHS, assets_filepath) 155 filename_tensor = tf.Variable( 156 original_assets_filename, 157 name="filename_tensor", 158 trainable=False, 159 collections=[]) 160 assign_filename_op = filename_tensor.assign(original_assets_filename) 161 162 # Set up the signature for Predict with input and output tensor 163 # specification. 164 predict_input_tensor = tf.saved_model.utils.build_tensor_info(x) 165 predict_signature_inputs = {"x": predict_input_tensor} 166 167 predict_output_tensor = tf.saved_model.utils.build_tensor_info(y) 168 predict_signature_outputs = {"y": predict_output_tensor} 169 predict_signature_def = ( 170 tf.saved_model.signature_def_utils.build_signature_def( 171 predict_signature_inputs, predict_signature_outputs, 172 tf.saved_model.signature_constants.PREDICT_METHOD_NAME)) 173 174 signature_def_map = { 175 "regress_x_to_y": 176 _build_regression_signature(serialized_tf_example, y), 177 "regress_x_to_y2": 178 _build_regression_signature(serialized_tf_example, y2), 179 "regress_x2_to_y3": 180 _build_regression_signature(x2, y3), 181 "classify_x_to_y": 182 _build_classification_signature(serialized_tf_example, y), 183 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 184 predict_signature_def 185 } 186 # Initialize all variables and then save the SavedModel. 187 sess.run(tf.global_variables_initializer()) 188 signature_def_map = { 189 "regress_x_to_y": 190 _build_regression_signature(serialized_tf_example, y), 191 "regress_x_to_y2": 192 _build_regression_signature(serialized_tf_example, y2), 193 "regress_x2_to_y3": 194 _build_regression_signature(x2, y3), 195 "classify_x_to_y": 196 _build_classification_signature(serialized_tf_example, y), 197 "classify_x2_to_y3": 198 _build_classification_signature(x2, y3), 199 tf.saved_model.signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY: 200 predict_signature_def 201 } 202 if use_main_op: 203 builder.add_meta_graph_and_variables( 204 sess, [tf.saved_model.tag_constants.SERVING], 205 signature_def_map=signature_def_map, 206 assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), 207 main_op=tf.group(tf.saved_model.main_op.main_op(), 208 assign_filename_op)) 209 else: 210 builder.add_meta_graph_and_variables( 211 sess, [tf.saved_model.tag_constants.SERVING], 212 signature_def_map=signature_def_map, 213 assets_collection=tf.get_collection(tf.GraphKeys.ASSET_FILEPATHS), 214 legacy_init_op=tf.group(assign_filename_op)) 215 builder.save(as_text) 216 217 218 def main(_): 219 _generate_saved_model_for_half_plus_two(FLAGS.output_dir) 220 print("SavedModel generated at: %s" % FLAGS.output_dir) 221 222 _generate_saved_model_for_half_plus_two(FLAGS.output_dir_pbtxt, as_text=True) 223 print("SavedModel generated at: %s" % FLAGS.output_dir_pbtxt) 224 225 _generate_saved_model_for_half_plus_two( 226 FLAGS.output_dir_main_op, use_main_op=True) 227 print("SavedModel generated at: %s" % FLAGS.output_dir_main_op) 228 229 230 if __name__ == "__main__": 231 parser = argparse.ArgumentParser() 232 parser.add_argument( 233 "--output_dir", 234 type=str, 235 default="/tmp/saved_model_half_plus_two", 236 help="Directory where to output SavedModel.") 237 parser.add_argument( 238 "--output_dir_pbtxt", 239 type=str, 240 default="/tmp/saved_model_half_plus_two_pbtxt", 241 help="Directory where to output the text format of SavedModel.") 242 parser.add_argument( 243 "--output_dir_main_op", 244 type=str, 245 default="/tmp/saved_model_half_plus_two_main_op", 246 help="Directory where to output the SavedModel with a main op.") 247 FLAGS, unparsed = parser.parse_known_args() 248 tf.app.run(main=main, argv=[sys.argv[0]] + unparsed) 249