Home | History | Annotate | Download | only in examples
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Simple script to write Inception-ResNet-v2 model to graph file.
     16 """
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import argparse
     23 import sys
     24 
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import graph_io
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.ops import array_ops
     29 from tensorflow.python.platform import app
     30 from nets import inception
     31 
     32 cmd_args = None
     33 
     34 
     35 def main(unused_argv):
     36   # Model definition.
     37   g = ops.Graph()
     38   with g.as_default():
     39     images = array_ops.placeholder(
     40         dtypes.float32, shape=(1, None, None, 3), name='input_image')
     41     inception.inception_resnet_v2_base(images)
     42 
     43   graph_io.write_graph(g.as_graph_def(), cmd_args.graph_dir,
     44                        cmd_args.graph_filename)
     45 
     46 
     47 if __name__ == '__main__':
     48   parser = argparse.ArgumentParser()
     49   parser.register('type', 'bool', lambda v: v.lower() == 'true')
     50   parser.add_argument(
     51       '--graph_dir',
     52       type=str,
     53       default='/tmp',
     54       help='Directory where graph will be saved.')
     55   parser.add_argument(
     56       '--graph_filename',
     57       type=str,
     58       default='graph.pbtxt',
     59       help='Filename of graph that will be saved.')
     60   cmd_args, unparsed = parser.parse_known_args()
     61   app.run(main=main, argv=[sys.argv[0]] + unparsed)
     62