# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph Placer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import hierarchical_controller
from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import training


def PlaceGraph(metagraph,
               cluster=None,
               allotted_time=3600,
               hparams=None,
               verbose=False):
  """Place the provided metagraph.

  Args:
    metagraph: the metagraph to place.
    cluster: an optional set of hardware resources to optimize the placement
      for. If none is specified, we'll optimize the placement for the hardware
      available on the local machine.
    allotted_time: the maximum amount of time in seconds to spend optimizing
      the placement.
    hparams: hyperparameters used to fine tune the placer.
    verbose: prints debug information if True.

  Returns:
    The placed metagraph.
  """
  if cluster is None:
    cluster = gcluster.Cluster()

  # Default the hyperparameters before they are first used: hparams.failing_signal
  # is needed below if the original placement can't be measured.
  if hparams is None:
    hparams = hierarchical_controller.hierarchical_controller_hparams()
  # We run with a single child.
  hparams.num_children = 1

  # Optimize the metagraph to speed up the placement.
  config = config_pb2.ConfigProto()
  optimized_graph = tf_optimizer.OptimizeGraph(
      config, metagraph, verbose=verbose, cluster=cluster)
  optimized_metagraph = meta_graph_pb2.MetaGraphDef()
  optimized_metagraph.CopyFrom(metagraph)
  optimized_metagraph.graph_def.CopyFrom(optimized_graph)

  item = gitem.Item(optimized_metagraph)

  # Measure the runtime achievable with the original placement.
  try:
    _, original_run_time, _ = cluster.MeasureCosts(item)
    if verbose:
      print("Runtime for original placement: " + str(original_run_time))
  except errors.OpError as e:
    if verbose:
      print("Original placement isn't feasible: " + str(e))
    original_run_time = hparams.failing_signal

  with tf_ops.Graph().as_default():
    # Place all the nodes of the controller on the CPU. We don't want them to
    # fight for accelerator memory with the model to optimize.
    with tf_ops.device("/device:CPU:0"):
      model = hierarchical_controller.HierarchicalController(
          hparams, item, cluster)
      ops = model.build_controller()
      session_creator = training.ChiefSessionCreator()
      with training.MonitoredSession(session_creator=session_creator) as sess:
        start_time = time.time()
        current_time = start_time
        while current_time - start_time < allotted_time:
          # Sample a grouping of the graph nodes, then a placement of the
          # groups, and evaluate the runtime of the resulting placement.
          grouping_actions = model.generate_grouping(sess)
          input_to_seq2seq = model.create_group_embeddings(
              grouping_actions, verbose=verbose)
          model.generate_placement(input_to_seq2seq, sess)
          try:
            run_time = model.eval_placement(sess, verbose=verbose)
          except errors.OpError as e:
            if verbose:
              print("Failed to run graph: " + str(e))
            run_time = hparams.failing_signal
          updated = model.update_reward(sess, run_time, verbose=verbose)
          if updated and run_time < original_run_time:
            if verbose:
              print("Found better placement, with runtime " + str(run_time))
            model.export_placement(metagraph)

          model.process_reward(sess)

          current_time = time.time()

  return metagraph
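

# A minimal usage sketch (illustrative only, not part of the module API). It
# assumes a TF 1.x-style workflow through tensorflow.compat.v1 and that the
# nodes to fetch are registered in the "train_op" collection, which grappler
# uses to identify the outputs of the item when importing the metagraph.
if __name__ == "__main__":
  import tensorflow.compat.v1 as tf

  g = tf.Graph()
  with g.as_default():
    a = tf.random_uniform([128, 128])
    b = tf.random_uniform([128, 128])
    prod = tf.matmul(a, b, name="matmul")
    # Grappler reads the "train_op" collection to determine the fetch nodes.
    tf.add_to_collection("train_op", prod)
    mg = tf.train.export_meta_graph(graph=g)

  # Spend at most five minutes searching for a better placement.
  placed = PlaceGraph(mg, allotted_time=300, verbose=True)
  # Device assignments are written back onto the nodes of placed.graph_def.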