# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Graph Placer."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import time
from tensorflow.core.protobuf import config_pb2
from tensorflow.core.protobuf import meta_graph_pb2
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops as tf_ops
from tensorflow.python.grappler import cluster as gcluster
from tensorflow.python.grappler import hierarchical_controller
from tensorflow.python.grappler import item as gitem
from tensorflow.python.grappler import tf_optimizer
from tensorflow.python.training import training


def PlaceGraph(metagraph,
               cluster=None,
               allotted_time=3600,
               hparams=None,
               verbose=False):
  """Place the provided metagraph.

  Args:
    metagraph: the metagraph to place.
    cluster: an optional set of hardware resources to optimize the placement
      for. If none is specified, we'll optimize the placement for the
      hardware available on the local machine.
    allotted_time: the maximum amount of time, in seconds, to spend
      optimizing the placement.
    hparams: hyperparameters used to fine-tune the placer.
    verbose: prints debug information if True.

  Returns:
    The placed metagraph.
  """
  if cluster is None:
    cluster = gcluster.Cluster()

  # Optimize the metagraph to speed up the placement.
  config = config_pb2.ConfigProto()
  optimized_graph = tf_optimizer.OptimizeGraph(
      config, metagraph, verbose=verbose, cluster=cluster)
  optimized_metagraph = meta_graph_pb2.MetaGraphDef()
  optimized_metagraph.CopyFrom(metagraph)
  optimized_metagraph.graph_def.CopyFrom(optimized_graph)

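  # Wrap the optimized metagraph in a Grappler item so the cluster can
  # measure the cost of candidate placements.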
  item = gitem.Item(optimized_metagraph)

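  # Resolve the controller hyperparameters before the first measurement:
  # the failure paths below use hparams.failing_signal as a penalty runtime.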
  if hparams is None:
    hparams = hierarchical_controller.hierarchical_controller_hparams()
  # We run with a single child.
  hparams.num_children = 1

  # Measure the runtime achievable with the original placement. If it isn't
  # feasible, fall back to the failing signal so that any feasible placement
  # can improve on it.
  try:
    _, original_run_time, _ = cluster.MeasureCosts(item)
    if verbose:
      print("Runtime for original placement: " + str(original_run_time))
  except errors.OpError as e:
    if verbose:
      print("Original placement isn't feasible: " + str(e))
    original_run_time = hparams.failing_signal

  with tf_ops.Graph().as_default():
    # Place all the nodes of the controller on the CPU. We don't want them to
    # fight for accelerator memory with the model to optimize.
    with tf_ops.device("/device:CPU:0"):
      model = hierarchical_controller.HierarchicalController(
          hparams, item, cluster)
      ops = model.build_controller()
      session_creator = training.ChiefSessionCreator()
      with training.MonitoredSession(session_creator=session_creator) as sess:
        start_time = time.time()
        current_time = start_time
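        # Controller training loop: sample a grouping of the ops, embed the
        # groups, sample a placement from the seq2seq model, measure the
        # resulting runtime, and feed that runtime back as the reward.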
        while current_time - start_time < allotted_time:
          grouping_actions = model.generate_grouping(sess)
          input_to_seq2seq = model.create_group_embeddings(
              grouping_actions, verbose=verbose)
          model.generate_placement(input_to_seq2seq, sess)
          try:
            run_time = model.eval_placement(
                sess,
                verbose=verbose)
          except errors.OpError as e:
            if verbose:
              print("Failed to run graph: " + str(e))
            run_time = hparams.failing_signal
          updated = model.update_reward(sess, run_time, verbose=verbose)
          if updated and run_time < original_run_time:
            if verbose:
              print("Found better placement, with runtime " + str(run_time))
            model.export_placement(metagraph)

          model.process_reward(sess)

          current_time = time.time()

  return metagraph
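
# Example (a sketch, assuming TF1-style graph construction and that this
# module is importable as tensorflow.python.grappler.graph_placer; the short
# allotted_time is only for illustration):
#
#   import tensorflow as tf
#   from tensorflow.python.grappler import graph_placer
#
#   with tf.Graph().as_default():
#     a = tf.constant([[1.0, 2.0]])
#     b = tf.constant([[3.0], [4.0]])
#     c = tf.matmul(a, b)
#     metagraph = tf.train.export_meta_graph()
#
#   placed = graph_placer.PlaceGraph(metagraph, allotted_time=60, verbose=True)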