1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Simple script to write Inception-ResNet-v2 model to graph file. 16 """ 17 18 from __future__ import absolute_import 19 from __future__ import division 20 from __future__ import print_function 21 22 import argparse 23 import sys 24 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import graph_io 27 from tensorflow.python.framework import ops 28 from tensorflow.python.ops import array_ops 29 from tensorflow.python.platform import app 30 from nets import inception 31 32 cmd_args = None 33 34 35 def main(unused_argv): 36 # Model definition. 37 g = ops.Graph() 38 with g.as_default(): 39 images = array_ops.placeholder( 40 dtypes.float32, shape=(1, None, None, 3), name='input_image') 41 inception.inception_resnet_v2_base(images) 42 43 graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir, 44 cmd_args.graph_filename) 45 46 47 if __name__ == '__main__': 48 parser = argparse.ArgumentParser() 49 parser.register('type', 'bool', lambda v: v.lower() == 'true') 50 parser.add_argument( 51 '--graph_dir', 52 type=str, 53 default='/tmp', 54 help='Directory where graph will be saved.') 55 parser.add_argument( 56 '--graph_filename', 57 type=str, 58 default='graph.pbtxt', 59 help='Filename of graph that will be saved.') 60 cmd_args, unparsed = parser.parse_known_args() 61 app.run(main=main, argv=[sys.argv[0]] + unparsed) 62