      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """HierarchicalController Class.
     16 
The HierarchicalController encompasses the entire lifecycle of training the
device placement policy: generating op embeddings, assigning each op to a
group, placing those groups on devices, and running the predicted placements.
     20 
     21 Different assignment models can inherit from this class.
     22 """
     23 
     24 from __future__ import absolute_import
     25 from __future__ import division
     26 from __future__ import print_function
     27 
     28 import math
     29 import numpy as np
     30 import six
     31 from tensorflow.python.framework import constant_op
     32 from tensorflow.python.framework import dtypes
     33 from tensorflow.python.framework import errors
     34 from tensorflow.python.framework import ops as tf_ops
     35 from tensorflow.python.grappler.controller import Controller
     36 from tensorflow.python.ops import array_ops
     37 from tensorflow.python.ops import clip_ops
     38 from tensorflow.python.ops import control_flow_ops
     39 from tensorflow.python.ops import embedding_ops
     40 from tensorflow.python.ops import init_ops
     41 from tensorflow.python.ops import linalg_ops
     42 from tensorflow.python.ops import math_ops
     43 from tensorflow.python.ops import nn_ops
     44 from tensorflow.python.ops import random_ops
     45 from tensorflow.python.ops import state_ops
     46 from tensorflow.python.ops import tensor_array_ops
     47 from tensorflow.python.ops import variable_scope
     48 from tensorflow.python.summary import summary
     49 from tensorflow.python.training import adam
     50 from tensorflow.python.training import gradient_descent
     51 from tensorflow.python.training import learning_rate_decay
     52 from tensorflow.python.training import training_util
     53 
     54 
     55 class PlacerParams(object):
     56   """Class to hold a set of placement parameters as name-value pairs.
     57 
     58   A typical usage is as follows:
     59 
     60   ```python
     61   # Create a PlacerParams object specifying names and values of the model
     62   # parameters:
     63   params = PlacerParams(hidden_size=128, decay_steps=50)
     64 
     65   # The parameters are available as attributes of the PlacerParams object:
  params.hidden_size ==> 128
  params.decay_steps ==> 50
     68   ```
     69 
     70   """
     71 
     72   def __init__(self, **kwargs):
     73     """Create an instance of `PlacerParams` from keyword arguments.
     74 
    The keyword arguments specify name-value pairs for the parameters.
     76     The parameter types are inferred from the type of the values passed.
     77 
    The parameter names are added as attributes of the `PlacerParams` object,
    and they can be accessed directly with dot notation, i.e. `params.<name>`.
     80 
     81     Example:
     82 
     83     ```python
     84     # Define 1 parameter: 'hidden_size'
     85     params = PlacerParams(hidden_size=128)
     86     params.hidden_size ==> 128
     87     ```
     88 
     89     Args:
     90       **kwargs: Key-value pairs where the key is the parameter name and
     91         the value is the value for the parameter.
     92     """
     93     for name, value in six.iteritems(kwargs):
     94       self.add_param(name, value)
     95 
     96   def add_param(self, name, value):
     97     """Adds {name, value} pair to hyperparameters.
     98 
     99     Args:
    100       name: Name of the hyperparameter.
    101       value: Value of the hyperparameter. Can be one of the following types:
    102         int, float, string, int list, float list, or string list.
    103 
    104     Raises:
    105       ValueError: if one of the arguments is invalid.
    106     """
    107     # Keys in kwargs are unique, but 'name' could be the name of a pre-existing
    108     # attribute of this object.  In that case we refuse to use it as a
    109     # parameter name.
    110     if getattr(self, name, None) is not None:
    111       raise ValueError("Parameter name is reserved: %s" % name)
    112     setattr(self, name, value)
    113 
    114 
    115 def hierarchical_controller_hparams():
    116   """Hyperparameters for hierarchical planner."""
    117   return PlacerParams(
    118       hidden_size=512,
    119       forget_bias_init=1.0,
    120       temperature=1.0,
    121       logits_std_noise=0.5,
    122       stop_noise_step=750,
    123       decay_steps=50,
    124       max_num_outputs=5,
    125       max_output_size=5,
    126       tanh_constant=1.0,
    127       adj_embed_dim=20,
    128       grouping_hidden_size=64,
    129       num_groups=None,
    130       bi_lstm=True,
    131       failing_signal=100,
    132       stop_sampling=500,
    133       start_with_failing_signal=True,
    134       always_update_baseline=False,
    135       bl_dec=0.9,
    136       grad_bound=1.0,
    137       lr=0.1,
    138       lr_dec=0.95,
    139       start_decay_step=400,
    140       optimizer_type="adam",
    141       stop_updating_after_steps=1000,
    142       name="hierarchical_controller",
    143       keep_prob=1.0,
    144       reward_function="sqrt",
    145       seed=1234,
    146       # distributed training params
    147       num_children=1)
    148 
    149 
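# A minimal usage sketch. The `metagraph_item` and `gpu_cluster` names below
# are placeholders for the item/cluster objects provided by the surrounding
# grappler tooling:
#
#   hparams = hierarchical_controller_hparams()
#   hparams.lr = 0.05  # adjust a default before building the model
#   controller = HierarchicalController(hparams, metagraph_item, gpu_cluster)
#   ops = controller.build_controller()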
    150 class HierarchicalController(Controller):
    151   """HierarchicalController class."""
    152 
    153   def __init__(self, hparams, item, cluster, controller_id=0):
    154     """HierarchicalController class initializer.
    155 
    156     Args:
    157       hparams: All hyper-parameters.
    158       item: The metagraph to place.
    159       cluster: The cluster of hardware devices to optimize for.
    160       controller_id: the id of the controller in a multi-controller setup.
    161     """
    162     super(HierarchicalController, self).__init__(item, cluster)
    163     self.ctrl_id = controller_id
    164     self.hparams = hparams
    165 
    166     if self.hparams.num_groups is None:
    167       self.num_groups = min(256, 20 * self.num_devices)
    168     else:
    169       self.num_groups = self.hparams.num_groups
    170 
    171     # creates self.op_embeddings and self.type_dict
    172     self.create_op_embeddings(verbose=False)
    173     # TODO(azalia) clean up embedding/group_embedding_size names
    174     self.group_emb_size = (
    175         2 * self.num_groups + len(self.type_dict) +
    176         self.hparams.max_num_outputs * self.hparams.max_output_size)
    177     self.embedding_size = self.group_emb_size
    178     self.initializer = init_ops.glorot_uniform_initializer(
    179         seed=self.hparams.seed)
    180 
    181     with variable_scope.variable_scope(
    182         self.hparams.name,
    183         initializer=self.initializer,
    184         reuse=variable_scope.AUTO_REUSE):
    185       # define parameters of feedforward
    186       variable_scope.get_variable("w_grouping_ff", [
    187           1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
    188           self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size
    189       ])
    190       variable_scope.get_variable(
    191           "w_grouping_softmax",
    192           [self.hparams.grouping_hidden_size, self.num_groups])
    193       if self.hparams.bi_lstm:
    194         variable_scope.get_variable("encoder_lstm_forward", [
    195             self.embedding_size + self.hparams.hidden_size / 2,
    196             2 * self.hparams.hidden_size
    197         ])
    198         variable_scope.get_variable("encoder_lstm_backward", [
    199             self.embedding_size + self.hparams.hidden_size / 2,
    200             2 * self.hparams.hidden_size
    201         ])
    202         variable_scope.get_variable(
    203             "device_embeddings", [self.num_devices, self.hparams.hidden_size])
    204         variable_scope.get_variable(
    205             "decoder_lstm",
    206             [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
    207         variable_scope.get_variable(
    208             "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
    209         variable_scope.get_variable("device_go_embedding",
    210                                     [1, self.hparams.hidden_size])
    211         variable_scope.get_variable(
    212             "encoder_forget_bias",
    213             shape=1,
    214             dtype=dtypes.float32,
    215             initializer=init_ops.constant_initializer(
    216                 self.hparams.forget_bias_init))
    217         variable_scope.get_variable(
    218             "decoder_forget_bias",
    219             shape=1,
    220             dtype=dtypes.float32,
    221             initializer=init_ops.constant_initializer(
    222                 self.hparams.forget_bias_init))
    223         variable_scope.get_variable(
    224             "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
    225         variable_scope.get_variable(
    226             "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
    227         variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
    228 
    229       else:
    230         variable_scope.get_variable("encoder_lstm", [
    231             self.embedding_size + self.hparams.hidden_size,
    232             4 * self.hparams.hidden_size
    233         ])
    234         variable_scope.get_variable(
    235             "device_embeddings", [self.num_devices, self.hparams.hidden_size])
    236         variable_scope.get_variable(
    237             "decoder_lstm",
    238             [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size])
    239         variable_scope.get_variable(
    240             "device_softmax", [2 * self.hparams.hidden_size, self.num_devices])
    241         variable_scope.get_variable("device_go_embedding",
    242                                     [1, self.hparams.hidden_size])
    243         variable_scope.get_variable(
    244             "encoder_forget_bias",
    245             shape=1,
    246             dtype=dtypes.float32,
    247             initializer=init_ops.constant_initializer(
    248                 self.hparams.forget_bias_init))
    249         variable_scope.get_variable(
    250             "decoder_forget_bias",
    251             shape=1,
    252             dtype=dtypes.float32,
    253             initializer=init_ops.constant_initializer(
    254                 self.hparams.forget_bias_init))
    255         variable_scope.get_variable(
    256             "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size])
    257         variable_scope.get_variable(
    258             "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size])
    259         variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1])
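    # Default input to the seq2seq placer is an all-zero tensor of group
    # embeddings; generate_placement() feeds the real group embeddings
    # through this placeholder.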
    260     seq2seq_input_layer = array_ops.placeholder_with_default(
    261         array_ops.zeros([self.hparams.num_children,
    262                          self.num_groups,
    263                          self.group_emb_size],
    264                         dtypes.float32),
    265         shape=(self.hparams.num_children, self.num_groups, self.group_emb_size))
    266     self.seq2seq_input_layer = seq2seq_input_layer
    267 
    268   def compute_reward(self, run_time):
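    """Maps a measured run time (in seconds) to the scalar reward.

    For example, with the default "sqrt" reward function a run time of 4.0
    seconds yields a reward of 2.0, which compresses the dynamic range of the
    raw run times.
    """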
    269     if self.hparams.reward_function == "id":
    270       reward = run_time
    271     elif self.hparams.reward_function == "sqrt":
    272       reward = math.sqrt(run_time)
    273     elif self.hparams.reward_function == "log":
    274       reward = math.log1p(run_time)
    275     else:
    276       raise NotImplementedError(
    277           "Unrecognized reward function '%s', consider your "
    278           "--reward_function flag value." % self.hparams.reward_function)
    279     return reward
    280 
    281   def build_controller(self):
    282     """RL optimization interface.
    283 
    284     Returns:
    285       ops: A dictionary holding handles of the model used for training.
    286     """
    287 
    288     self._global_step = training_util.get_or_create_global_step()
    289     ops = {}
    290     ops["loss"] = 0
    291 
    292     failing_signal = self.compute_reward(self.hparams.failing_signal)
    293 
    294     ctr = {}
    295 
    296     with tf_ops.name_scope("controller_{}".format(self.ctrl_id)):
    297       with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
    298         ctr["reward"] = {"value": [], "ph": [], "update": []}
    299         ctr["ready"] = {"value": [], "ph": [], "update": []}
    300         ctr["best_reward"] = {"value": [], "update": []}
    301         for i in range(self.hparams.num_children):
    302           reward_value = variable_scope.get_local_variable(
    303               "reward_{}".format(i),
    304               initializer=0.0,
    305               dtype=dtypes.float32,
    306               trainable=False)
    307           reward_ph = array_ops.placeholder(
    308               dtypes.float32, shape=(), name="reward_ph_{}".format(i))
    309           reward_update = state_ops.assign(
    310               reward_value, reward_ph, use_locking=True)
    311           ctr["reward"]["value"].append(reward_value)
    312           ctr["reward"]["ph"].append(reward_ph)
    313           ctr["reward"]["update"].append(reward_update)
    314           best_reward = variable_scope.get_local_variable(
    315               "best_reward_{}".format(i),
    316               initializer=failing_signal,
    317               dtype=dtypes.float32,
    318               trainable=False)
    319           ctr["best_reward"]["value"].append(best_reward)
    320           ctr["best_reward"]["update"].append(
    321               state_ops.assign(best_reward,
    322                                math_ops.minimum(best_reward, reward_update)))
    323 
    324           ready_value = variable_scope.get_local_variable(
    325               "ready_{}".format(i),
    326               initializer=True,
    327               dtype=dtypes.bool,
    328               trainable=False)
    329           ready_ph = array_ops.placeholder(
    330               dtypes.bool, shape=(), name="ready_ph_{}".format(i))
    331           ready_update = state_ops.assign(
    332               ready_value, ready_ph, use_locking=True)
    333           ctr["ready"]["value"].append(ready_value)
    334           ctr["ready"]["ph"].append(ready_ph)
    335           ctr["ready"]["update"].append(ready_update)
    336 
    337       ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings()
    338       summary.histogram(
    339           "grouping_actions",
    340           array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0],
    341                           [1, array_ops.shape(self.op_embeddings)[0]]))
    342 
    343       with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
    344         ctr["baseline"] = variable_scope.get_local_variable(
    345             "baseline",
    346             initializer=failing_signal
    347             if self.hparams.start_with_failing_signal else 0.0,
    348             dtype=dtypes.float32,
    349             trainable=False)
    350 
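      # Exponential-moving-average baseline over rewards. Unless
      # always_update_baseline is set, once global_step passes
      # stop_updating_after_steps only rewards from non-failing runs
      # (reward < failing_signal) contribute to the baseline.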
    351       new_baseline = self.hparams.bl_dec * ctr["baseline"] + (
    352           1 - self.hparams.bl_dec) * math_ops.reduce_mean(
    353               ctr["reward"]["value"])
    354       if not self.hparams.always_update_baseline:
    355         baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal)
    356         selected_reward = array_ops.boolean_mask(ctr["reward"]["value"],
    357                                                  baseline_mask)
    358         selected_baseline = control_flow_ops.cond(
    359             math_ops.reduce_any(baseline_mask),
    360             lambda: math_ops.reduce_mean(selected_reward),
    361             lambda: constant_op.constant(0, dtype=dtypes.float32))
    362         ctr["pos_reward"] = selected_baseline
    363         pos_ = math_ops.less(
    364             constant_op.constant(0, dtype=dtypes.float32), selected_baseline)
    365         selected_baseline = self.hparams.bl_dec * ctr["baseline"] + (
    366             1 - self.hparams.bl_dec) * selected_baseline
    367         selected_baseline = control_flow_ops.cond(
    368             pos_, lambda: selected_baseline, lambda: ctr["baseline"])
    369         new_baseline = control_flow_ops.cond(
    370             math_ops.less(self.global_step,
    371                           self.hparams.stop_updating_after_steps),
    372             lambda: new_baseline, lambda: selected_baseline)
    373       ctr["baseline_update"] = state_ops.assign(
    374           ctr["baseline"], new_baseline, use_locking=True)
    375 
    376       ctr["y_preds"], ctr["log_probs"] = self.get_placements()
    377       summary.histogram("actions", ctr["y_preds"]["sample"])
    378       mask = math_ops.less(ctr["reward"]["value"], failing_signal)
    379       ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"]
    380       ctr["loss"] *= (
    381           ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"])
    382 
    383       selected_loss = array_ops.boolean_mask(ctr["loss"], mask)
    384       selected_loss = control_flow_ops.cond(
    385           math_ops.reduce_any(mask),
    386           lambda: math_ops.reduce_mean(-selected_loss),
    387           lambda: constant_op.constant(0, dtype=dtypes.float32))
    388 
    389       ctr["loss"] = control_flow_ops.cond(
    390           math_ops.less(self.global_step,
    391                         self.hparams.stop_updating_after_steps),
    392           lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss)
    393 
    394       ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"])
    395       summary.scalar("loss", ctr["loss"])
    396       summary.scalar("avg_reward", ctr["reward_s"])
    397       summary.scalar("best_reward_so_far", best_reward)
    398       summary.scalar(
    399           "advantage",
    400           math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"]))
    401 
    402     with variable_scope.variable_scope(
    403         "optimizer", reuse=variable_scope.AUTO_REUSE):
    404       (ctr["train_op"], ctr["lr"], ctr["grad_norm"],
    405        ctr["grad_norms"]) = self._get_train_ops(
    406            ctr["loss"],
    407            tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES),
    408            self.global_step,
    409            grad_bound=self.hparams.grad_bound,
    410            lr_init=self.hparams.lr,
    411            lr_dec=self.hparams.lr_dec,
    412            start_decay_step=self.hparams.start_decay_step,
    413            decay_steps=self.hparams.decay_steps,
    414            optimizer_type=self.hparams.optimizer_type)
    415 
    416     summary.scalar("gradnorm", ctr["grad_norm"])
    417     summary.scalar("lr", ctr["lr"])
    418     ctr["summary"] = summary.merge_all()
    419     ops["controller"] = ctr
    420 
    421     self.ops = ops
    422     return ops
    423 
    424   @property
    425   def global_step(self):
    426     return self._global_step
    427 
    428   def create_op_embeddings(self, verbose=False):
    429     if verbose:
    430       print("process input graph for op embeddings")
    431     self.num_ops = len(self.important_ops)
    432     # topological sort of important nodes
    433     topo_order = [op.name for op in self.important_ops]
    434 
    # map names of topologically sorted important nodes to their indices
    436     name_to_topo_order_index = {}
    437     for idx, x in enumerate(topo_order):
    438       name_to_topo_order_index[x] = idx
    439     self.name_to_topo_order_index = name_to_topo_order_index
    440 
    # create adjacency lists, keyed by topological index
    442     adj_dict = {}
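    # Each entry maps an op's topological index to a flat list of
    # (neighbor index, direction) pairs: +1 marks an edge to a consumer
    # (fanout), -1 marks the mirrored edge recorded on the consumer (fanin).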
    443     for idx, op in enumerate(self.important_ops):
    444       for output_op in self.get_node_fanout(op):
    445         output_op_name = output_op.name
    446         if output_op_name in self.important_op_names:
    447           if name_to_topo_order_index[op.name] not in adj_dict:
    448             adj_dict[name_to_topo_order_index[op.name]] = []
    449           adj_dict[name_to_topo_order_index[op.name]].extend(
    450               [name_to_topo_order_index[output_op_name], 1])
    451           if output_op_name not in adj_dict:
    452             adj_dict[name_to_topo_order_index[output_op_name]] = []
    453           adj_dict[name_to_topo_order_index[output_op_name]].extend(
    454               [name_to_topo_order_index[op.name], -1])
    455 
    # get op_type, op_output_shape, and adjacency info
    457     output_embed_dim = (self.hparams.max_num_outputs *
    458                         self.hparams.max_output_size)
    459 
    460     # TODO(bsteiner): don't filter based on used ops so that we can generalize
    461     # to models that use other types of ops.
    462     used_ops = set()
    463     for node in self.important_ops:
    464       op_type = str(node.op)
    465       used_ops.add(op_type)
    466 
    467     self.type_dict = {}
    468     for op_type in self.cluster.ListAvailableOps():
    469       if op_type in used_ops:
    470         self.type_dict[op_type] = len(self.type_dict)
    471 
    472     op_types = np.zeros([self.num_ops], dtype=np.int32)
    473     op_output_shapes = np.full(
    474         [self.num_ops, output_embed_dim], -1.0, dtype=np.float32)
    475     for idx, node in enumerate(self.important_ops):
    476       op_types[idx] = self.type_dict[node.op]
    477       # output shape
    478       op_name = node.name
    479       for i, output_prop in enumerate(self.node_properties[op_name]):
    480         if output_prop.shape.__str__() == "<unknown>":
    481           continue
    482         shape = output_prop.shape
    483         for j, dim in enumerate(shape.dim):
    484           if dim.size >= 0:
    485             if i * self.hparams.max_output_size + j >= output_embed_dim:
    486               break
    487             op_output_shapes[idx,
    488                              i * self.hparams.max_output_size + j] = dim.size
    # pad (or truncate) each adjacency list to adj_embed_dim entries
    490     op_adj = np.full(
    491         [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32)
    492     for idx in adj_dict:
    493       neighbors = adj_dict[int(idx)]
    494       min_dim = min(self.hparams.adj_embed_dim, len(neighbors))
    495       padding_size = self.hparams.adj_embed_dim - min_dim
    496       neighbors = neighbors[:min_dim] + [0] * padding_size
    497       op_adj[int(idx)] = neighbors
    498 
    # op_embedding starts here
    500     op_embeddings = np.zeros(
    501         [
    502             self.num_ops,
    503             1 + self.hparams.max_num_outputs * self.hparams.max_output_size +
    504             self.hparams.adj_embed_dim
    505         ],
    506         dtype=np.float32)
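    # Each op embedding is the concatenation of
    # [type id | flattened output shapes (padded with -1) | adjacency info
    #  (padded with 0)], matching the input width of w_grouping_ff.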
    507     for idx, op_name in enumerate(topo_order):
    508       op_embeddings[idx] = np.concatenate(
    509           (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)]))
    510     self.op_embeddings = constant_op.constant(
    511         op_embeddings, dtype=dtypes.float32)
    512     if verbose:
    513       print("num_ops = {}".format(self.num_ops))
    514       print("num_types = {}".format(len(self.type_dict)))
    515 
    516   def get_groupings(self, *args, **kwargs):
    517     num_children = self.hparams.num_children
    518     with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
    519       grouping_actions_cache = variable_scope.get_local_variable(
    520           "grouping_actions_cache",
    521           initializer=init_ops.zeros_initializer,
    522           dtype=dtypes.int32,
    523           shape=[num_children, self.num_ops],
    524           trainable=False)
    525     input_layer = self.op_embeddings
    526     input_layer = array_ops.expand_dims(input_layer, 0)
    527     feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1])
    528     grouping_actions, grouping_log_probs = {}, {}
    529     grouping_actions["sample"], grouping_log_probs[
    530         "sample"] = self.make_grouping_predictions(feed_ff_input_layer)
    531 
    532     grouping_actions["sample"] = state_ops.assign(grouping_actions_cache,
    533                                                   grouping_actions["sample"])
    534     self.grouping_actions_cache = grouping_actions_cache
    535 
    536     return grouping_actions, grouping_log_probs
    537 
    538   def make_grouping_predictions(self, input_layer, reuse=None):
    539     """model that predicts grouping (grouping_actions).
    540 
    541     Args:
    542       input_layer: group_input_layer
    543       reuse: reuse
    544 
    545     Returns:
    546        grouping_actions: actions
    547        grouping_log_probs: log probabilities corresponding to actions
    548     """
    549     with variable_scope.variable_scope(self.hparams.name, reuse=True):
      # input_layer: tensor of size [num_children, num_ops, embedding size]
    551       w_grouping_ff = variable_scope.get_variable("w_grouping_ff")
    552       w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax")
    553 
    554     batch_size = array_ops.shape(input_layer)[0]
    555     embedding_dim = array_ops.shape(input_layer)[2]
    556 
    557     reshaped = array_ops.reshape(input_layer,
    558                                  [batch_size * self.num_ops, embedding_dim])
    559     ff_output = math_ops.matmul(reshaped, w_grouping_ff)
    560     logits = math_ops.matmul(ff_output, w_grouping_softmax)
    561     if self.hparams.logits_std_noise > 0:
    562       num_in_logits = math_ops.cast(
    563           array_ops.size(logits), dtype=dtypes.float32)
    564       avg_norm = math_ops.divide(
    565           linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
    566       logits_noise = random_ops.random_normal(
    567           array_ops.shape(logits),
    568           stddev=self.hparams.logits_std_noise * avg_norm)
    569       logits = control_flow_ops.cond(
    570           self.global_step > self.hparams.stop_noise_step, lambda: logits,
    571           lambda: logits + logits_noise)
    572     logits = array_ops.reshape(logits,
    573                                [batch_size * self.num_ops, self.num_groups])
    574     actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
    575     actions = math_ops.cast(actions, dtypes.int32)
    576     actions = array_ops.reshape(actions, [batch_size, self.num_ops])
    577     action_label = array_ops.reshape(actions, [-1])
    578     log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits(
    579         logits=logits, labels=action_label)
    580     log_probs = array_ops.reshape(log_probs, [batch_size, -1])
    581     log_probs = math_ops.reduce_sum(log_probs, 1)
    582     grouping_actions = actions
    583     grouping_log_probs = log_probs
    584     return grouping_actions, grouping_log_probs
    585 
    586   def create_group_embeddings(self, grouping_actions, verbose=False):
    587     """Approximating the blocks of a TF graph from a graph_def.
    588 
    589     Args:
    590       grouping_actions: grouping predictions.
    591       verbose: print stuffs.
    592 
    593     Returns:
    594       groups: list of groups.
    595     """
    596     groups = [
    597         self._create_group_embeddings(grouping_actions, i, verbose) for
    598         i in range(self.hparams.num_children)
    599     ]
    600     return np.stack(groups, axis=0)
    601 
    602   def _create_group_embeddings(self, grouping_actions, child_id, verbose=False):
    603     """Approximating the blocks of a TF graph from a graph_def for each child.
    604 
    605     Args:
    606       grouping_actions: grouping predictions.
    607       child_id: child_id for the group.
    608       verbose: print stuffs.
    609 
    610     Returns:
    611       groups: group embedding for the child_id.
    612     """
    613     if verbose:
    614       print("Processing input_graph")
    615 
    616     # TODO(azalia): Build inter-adjacencies dag matrix.
    617     # record dag_matrix
    618     dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32)
    619     for op in self.important_ops:
    620       topo_op_index = self.name_to_topo_order_index[op.name]
    621       group_index = grouping_actions[child_id][topo_op_index]
    622       for output_op in self.get_node_fanout(op):
    623         if output_op.name not in self.important_op_names:
    624           continue
    625         output_group_index = (
    626             grouping_actions[child_id][self.name_to_topo_order_index[
    627                 output_op.name]])
    628         dag_matrix[group_index, output_group_index] += 1.0
    629     num_connections = np.sum(dag_matrix)
    630     num_intra_group_connections = dag_matrix.trace()
    631     num_inter_group_connections = num_connections - num_intra_group_connections
    632     if verbose:
    633       print("grouping evaluation metric")
    634       print(("num_connections={} num_intra_group_connections={} "
    635              "num_inter_group_connections={}").format(
    636                  num_connections, num_intra_group_connections,
    637                  num_inter_group_connections))
    638     self.dag_matrix = dag_matrix
    639 
    640     # output_shape
    641     op_output_shapes = np.zeros(
    642         [
    643             len(self.important_ops),
    644             self.hparams.max_num_outputs * self.hparams.max_output_size
    645         ],
    646         dtype=np.float32)
    647 
    648     for idx, op in enumerate(self.important_ops):
    649       for i, output_properties in enumerate(self.node_properties[op.name]):
    650         if output_properties.shape.__str__() == "<unknown>":
    651           continue
    652         if i > self.hparams.max_num_outputs:
    653           break
    654         shape = output_properties.shape
    655         for j, dim in enumerate(shape.dim):
    656           if dim.size > 0:
    657             k = i * self.hparams.max_output_size + j
    658             if k >= self.hparams.max_num_outputs * self.hparams.max_output_size:
    659               break
    660             op_output_shapes[idx, k] = dim.size
    661 
    662     # group_embedding
    663     group_embedding = np.zeros(
    664         [
    665             self.num_groups, len(self.type_dict) +
    666             self.hparams.max_num_outputs * self.hparams.max_output_size
    667         ],
    668         dtype=np.float32)
    669     for op_index, op in enumerate(self.important_ops):
    670       group_index = grouping_actions[child_id][
    671           self.name_to_topo_order_index[op.name]]
    672       type_name = str(op.op)
    673       type_index = self.type_dict[type_name]
    674       group_embedding[group_index, type_index] += 1
    675       group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams.
    676                       max_output_size] += (
    677                           op_output_shapes[op_index])
    678     grouping_adjacencies = np.concatenate(
    679         [dag_matrix, np.transpose(dag_matrix)], axis=1)
    680     group_embedding = np.concatenate(
    681         [grouping_adjacencies, group_embedding], axis=1)
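    # Normalize each group's feature row by its maximum value; the +1 avoids
    # division by zero for empty groups.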
    682     group_normalizer = np.amax(group_embedding, axis=1, keepdims=True)
    683     group_embedding /= (group_normalizer + 1.0)
    684     if verbose:
    685       print("Finished Processing Input Graph")
    686     return group_embedding
    687 
    688   def get_placements(self, *args, **kwargs):
    689     num_children = self.hparams.num_children
    690     with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)):
    691       actions_cache = variable_scope.get_local_variable(
    692           "actions_cache",
    693           initializer=init_ops.zeros_initializer,
    694           dtype=dtypes.int32,
    695           shape=[num_children, self.num_groups],
    696           trainable=False)
    697 
    698     x = self.seq2seq_input_layer
    699     last_c, last_h, attn_mem = self.encode(x)
    700     actions, log_probs = {}, {}
    701     actions["sample"], log_probs["sample"] = (
    702         self.decode(
    703             x, last_c, last_h, attn_mem, mode="sample"))
    704     actions["target"], log_probs["target"] = (
    705         self.decode(
    706             x,
    707             last_c,
    708             last_h,
    709             attn_mem,
    710             mode="target",
    711             y=actions_cache))
    712     actions["greedy"], log_probs["greedy"] = (
    713         self.decode(
    714             x, last_c, last_h, attn_mem, mode="greedy"))
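    # Until `stop_sampling` steps the cache is refreshed with newly sampled
    # placements; afterwards the previously cached placements are replayed
    # through the "target" decode, effectively freezing the sampled policy.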
    715     actions["sample"] = control_flow_ops.cond(
    716         self.global_step < self.hparams.stop_sampling,
    717         lambda: state_ops.assign(actions_cache, actions["sample"]),
    718         lambda: state_ops.assign(actions_cache, actions["target"]))
    719     self.actions_cache = actions_cache
    720 
    721     return actions, log_probs
    722 
    723   def encode(self, x):
    724     """Encoder using LSTM.
    725 
    726     Args:
    727       x: tensor of size [num_children, num_groups, embedding_size]
    728 
    729     Returns:
    730       last_c, last_h: tensors of size [num_children, hidden_size], the final
    731         LSTM states
      attn_mem: tensor of size [num_children, num_groups, hidden_size], the
        attention memory, i.e. the concatenation of all hidden states,
        linearly transformed by the attention matrix attn_w_1
    736     """
    737     if self.hparams.bi_lstm:
    738       with variable_scope.variable_scope(self.hparams.name, reuse=True):
    739         w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward")
    740         w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward")
    741         forget_bias = variable_scope.get_variable("encoder_forget_bias")
    742         attn_w_1 = variable_scope.get_variable("attn_w_1")
    743     else:
    744       with variable_scope.variable_scope(self.hparams.name, reuse=True):
    745         w_lstm = variable_scope.get_variable("encoder_lstm")
    746         forget_bias = variable_scope.get_variable("encoder_forget_bias")
    747         attn_w_1 = variable_scope.get_variable("attn_w_1")
    748 
    749     embedding_size = array_ops.shape(x)[2]
    750 
    751     signals = array_ops.split(x, self.num_groups, axis=1)
    752     for i in range(len(signals)):
    753       signals[i] = array_ops.reshape(
    754           signals[i], [self.hparams.num_children, embedding_size])
    755 
    756     if self.hparams.bi_lstm:
    757 
    758       def body(i, prev_c_forward, prev_h_forward, prev_c_backward,
    759                prev_h_backward):
    760         """while loop for LSTM."""
    761         signal_forward = signals[i]
    762         next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward,
    763                                               prev_h_forward, w_lstm_forward,
    764                                               forget_bias)
    765 
    766         signal_backward = signals[self.num_groups - 1 - i]
    767         next_c_backward, next_h_backward = lstm(
    768             signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward,
    769             forget_bias)
    770 
    771         next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1)
    772         all_h.append(next_h)
    773 
    774         return (next_c_forward, next_h_forward, next_c_backward,
    775                 next_h_backward)
    776 
    777       c_forward = array_ops.zeros(
    778           [self.hparams.num_children, self.hparams.hidden_size / 2],
    779           dtype=dtypes.float32)
    780       h_forward = array_ops.zeros(
    781           [self.hparams.num_children, self.hparams.hidden_size / 2],
    782           dtype=dtypes.float32)
    783 
    784       c_backward = array_ops.zeros(
    785           [self.hparams.num_children, self.hparams.hidden_size / 2],
    786           dtype=dtypes.float32)
    787       h_backward = array_ops.zeros(
    788           [self.hparams.num_children, self.hparams.hidden_size / 2],
    789           dtype=dtypes.float32)
    790       all_h = []
    791 
    792       for i in range(0, self.num_groups):
    793         c_forward, h_forward, c_backward, h_backward = body(
    794             i, c_forward, h_forward, c_backward, h_backward)
    795 
    796       last_c = array_ops.concat([c_forward, c_backward], axis=1)
    797       last_h = array_ops.concat([h_forward, h_backward], axis=1)
    798       attn_mem = array_ops.stack(all_h)
    799 
    800     else:
    801 
    802       def body(i, prev_c, prev_h):
    803         signal = signals[i]
    804         next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
    805         all_h.append(next_h)
    806         return next_c, next_h
    807 
    808       c = array_ops.zeros(
    809           [self.hparams.num_children, self.hparams.hidden_size],
    810           dtype=dtypes.float32)
    811       h = array_ops.zeros(
    812           [self.hparams.num_children, self.hparams.hidden_size],
    813           dtype=dtypes.float32)
    814       all_h = []
    815 
    816       for i in range(0, self.num_groups):
    817         c, h = body(i, c, h)
    818 
    819       last_c = c
    820       last_h = h
    821       attn_mem = array_ops.stack(all_h)
    822 
    823     attn_mem = array_ops.transpose(attn_mem, [1, 0, 2])
    824     attn_mem = array_ops.reshape(
    825         attn_mem,
    826         [self.hparams.num_children * self.num_groups, self.hparams.hidden_size])
    827     attn_mem = math_ops.matmul(attn_mem, attn_w_1)
    828     attn_mem = array_ops.reshape(
    829         attn_mem,
    830         [self.hparams.num_children, self.num_groups, self.hparams.hidden_size])
    831 
    832     return last_c, last_h, attn_mem
    833 
    834   def decode(self,
    835              x,
    836              last_c,
    837              last_h,
    838              attn_mem,
    839              mode="target",
    840              y=None):
    841     """Decoder using LSTM.
    842 
    843     Args:
    844       x: tensor of size [num_children, num_groups, embedding_size].
    845       last_c: tensor of size [num_children, hidden_size], the final LSTM states
    846           computed by self.encoder.
    847       last_h: same as last_c.
    848       attn_mem: tensor of size [num_children, num_groups, hidden_size].
    849       mode: "target" or "sample".
    850       y: tensor of size [num_children, num_groups], the device placements.
    851 
    852     Returns:
    853       actions: tensor of size [num_children, num_groups], the placements of
    854           devices
    855     """
    856     with variable_scope.variable_scope(self.hparams.name, reuse=True):
    857       w_lstm = variable_scope.get_variable("decoder_lstm")
    858       forget_bias = variable_scope.get_variable("decoder_forget_bias")
    859       device_embeddings = variable_scope.get_variable("device_embeddings")
    860       device_softmax = variable_scope.get_variable("device_softmax")
    861       device_go_embedding = variable_scope.get_variable("device_go_embedding")
    862       attn_w_2 = variable_scope.get_variable("attn_w_2")
    863       attn_v = variable_scope.get_variable("attn_v")
    864 
    865     actions = tensor_array_ops.TensorArray(
    866         dtypes.int32,
    867         size=self.num_groups,
    868         infer_shape=False,
    869         clear_after_read=False)
    870 
    871     # pylint: disable=unused-argument
    872     def condition(i, *args):
    873       return math_ops.less(i, self.num_groups)
    874 
    875     # pylint: disable=missing-docstring
    876     def body(i, prev_c, prev_h, actions, log_probs):
    877       # pylint: disable=g-long-lambda
    878       signal = control_flow_ops.cond(
    879           math_ops.equal(i, 0),
    880           lambda: array_ops.tile(device_go_embedding,
    881                                  [self.hparams.num_children, 1]),
    882           lambda: embedding_ops.embedding_lookup(device_embeddings,
    883                                                  actions.read(i - 1))
    884       )
    885       if self.hparams.keep_prob is not None:
    886         signal = nn_ops.dropout(signal, self.hparams.keep_prob)
    887       next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
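      # Additive attention over the encoder memory: score each group with
      # attn_v^T tanh(attn_w_2 * h + attn_mem), softmax over groups, and use
      # the attention-weighted sum of attn_mem as extra context for the
      # device softmax (attn_mem already carries the attn_w_1 projection).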
    888       query = math_ops.matmul(next_h, attn_w_2)
    889       query = array_ops.reshape(
    890           query, [self.hparams.num_children, 1, self.hparams.hidden_size])
    891       query = math_ops.tanh(query + attn_mem)
    892       query = array_ops.reshape(query, [
    893           self.hparams.num_children * self.num_groups, self.hparams.hidden_size
    894       ])
    895       query = math_ops.matmul(query, attn_v)
    896       query = array_ops.reshape(query,
    897                                 [self.hparams.num_children, self.num_groups])
    898       query = nn_ops.softmax(query)
    899       query = array_ops.reshape(query,
    900                                 [self.hparams.num_children, self.num_groups, 1])
    901       query = math_ops.reduce_sum(attn_mem * query, axis=1)
    902       query = array_ops.concat([next_h, query], axis=1)
    903       logits = math_ops.matmul(query, device_softmax)
    904       logits /= self.hparams.temperature
    905       if self.hparams.tanh_constant > 0:
    906         logits = math_ops.tanh(logits) * self.hparams.tanh_constant
    907       if self.hparams.logits_std_noise > 0:
    908         num_in_logits = math_ops.cast(
    909             array_ops.size(logits), dtype=dtypes.float32)
    910         avg_norm = math_ops.divide(
    911             linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
    912         logits_noise = random_ops.random_normal(
    913             array_ops.shape(logits),
    914             stddev=self.hparams.logits_std_noise * avg_norm)
    915         logits = control_flow_ops.cond(
    916             self.global_step > self.hparams.stop_noise_step, lambda: logits,
    917             lambda: logits + logits_noise)
    918 
    919       if mode == "sample":
    920         next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
    921       elif mode == "greedy":
    922         next_y = math_ops.argmax(logits, 1)
    923       elif mode == "target":
    924         next_y = array_ops.slice(y, [0, i], [-1, 1])
    925       else:
    926         raise NotImplementedError
    927       next_y = math_ops.cast(next_y, dtypes.int32)
    928       next_y = array_ops.reshape(next_y, [self.hparams.num_children])
    929       actions = actions.write(i, next_y)
    930       log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
    931           logits=logits, labels=next_y)
    932       return i + 1, next_c, next_h, actions, log_probs
    933 
    934     loop_vars = [
    935         constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions,
    936         array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
    937     ]
    938     loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)
    939 
    940     last_c = loop_outputs[-4]
    941     last_h = loop_outputs[-3]
    942     actions = loop_outputs[-2].stack()
    943     actions = array_ops.transpose(actions, [1, 0])
    944     log_probs = loop_outputs[-1]
    945     return actions, log_probs
    946 
    947   def eval_placement(self,
    948                      sess,
    949                      child_id=0,
    950                      verbose=False):
    951     grouping_actions, actions = sess.run([
    952         self.grouping_actions_cache,
    953         self.actions_cache
    954     ])
    955     grouping_actions = grouping_actions[child_id]
    956     actions = actions[child_id]
    957     if verbose:
    958       global_step = sess.run(self.global_step)
    959       if global_step % 100 == 0:
    960         log_string = "op group assignments: "
    961         for a in grouping_actions:
    962           log_string += "{} ".format(a)
    963         print(log_string[:-1])
    964         log_string = "group device assignments: "
    965         for a in actions:
    966           log_string += "{} ".format(a)
    967         print(log_string[:-1])
    968 
    969     for op in self.important_ops:
    970       topo_order_index = self.name_to_topo_order_index[op.name]
    971       group_index = grouping_actions[topo_order_index]
    972       op.device = self.devices[actions[group_index]].name
    973     try:
    974       _, run_time, _ = self.cluster.MeasureCosts(self.item)
    975     except errors.ResourceExhaustedError:
    976       run_time = self.hparams.failing_signal
    977     return run_time
    978 
    979   def update_reward(self,
    980                     sess,
    981                     run_time,
    982                     child_id=0,
    983                     verbose=False):
    984     reward = self.compute_reward(run_time)
    985     controller_ops = self.ops["controller"]
    986     _, best_reward = sess.run(
    987         [
    988             controller_ops["reward"]["update"][child_id],
    989             controller_ops["best_reward"]["update"][child_id]
    990         ],
    991         feed_dict={
    992             controller_ops["reward"]["ph"][child_id]: reward,
    993         })
    994     if verbose:
    995       print(("run_time={:<.5f} reward={:<.5f} "
    996              "best_reward={:<.5f}").format(run_time, reward, best_reward))
    997 
    998     # Reward is a double, best_reward a float: allow for some slack in the
    999     # comparison.
   1000     updated = abs(best_reward - reward) < 1e-6
   1001     return updated
   1002 
   1003   def generate_grouping(self, sess):
   1004     controller_ops = self.ops["controller"]
   1005     grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
   1006     return grouping_actions
   1007 
   1008   def generate_placement(self, grouping, sess):
   1009     controller_ops = self.ops["controller"]
   1010     feed_seq2seq_input_dict = {}
   1011     feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping
   1012     sess.run(
   1013         controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)
   1014 
   1015   def process_reward(self, sess):
   1016     controller_ops = self.ops["controller"]
   1017     run_ops = [
   1018         controller_ops["loss"], controller_ops["lr"],
   1019         controller_ops["grad_norm"], controller_ops["grad_norms"],
   1020         controller_ops["train_op"]
   1021     ]
   1022     sess.run(run_ops)
   1023     sess.run(controller_ops["baseline_update"])
   1024 
   1025   def _get_train_ops(self,
   1026                      loss,
   1027                      tf_variables,
   1028                      global_step,
   1029                      grad_bound=1.25,
   1030                      lr_init=1e-3,
   1031                      lr_dec=0.9,
   1032                      start_decay_step=10000,
   1033                      decay_steps=100,
   1034                      optimizer_type="adam"):
   1035     """Loss optimizer.
   1036 
   1037     Args:
   1038       loss: scalar tf tensor
   1039       tf_variables: list of training variables, typically
   1040         tf.trainable_variables()
   1041       global_step: global_step
   1042       grad_bound: max gradient norm
   1043       lr_init: initial learning rate
      lr_dec: learning rate decay coefficient
   1045       start_decay_step: start decaying learning rate after this many steps
      decay_steps: apply the decay factor at intervals of this many steps
      optimizer_type: optimizer type, either "adam" or "sgd"
   1048 
   1049     Returns:
   1050       train_op: training op
   1051       learning_rate: scalar learning rate tensor
   1052       grad_norm: l2 norm of the gradient vector
   1053       all_grad_norms: l2 norm of each component
   1054     """
   1055     lr_gstep = global_step - start_decay_step
   1056 
   1057     def f1():
   1058       return constant_op.constant(lr_init)
   1059 
   1060     def f2():
   1061       return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
   1062                                                    decay_steps, lr_dec, True)
   1063 
   1064     learning_rate = control_flow_ops.cond(
   1065         math_ops.less(global_step, start_decay_step),
   1066         f1,
   1067         f2,
   1068         name="learning_rate")
   1069 
   1070     if optimizer_type == "adam":
   1071       opt = adam.AdamOptimizer(learning_rate)
   1072     elif optimizer_type == "sgd":
   1073       opt = gradient_descent.GradientDescentOptimizer(learning_rate)
   1074     grads_and_vars = opt.compute_gradients(loss, tf_variables)
   1075     grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
   1076     all_grad_norms = {}
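    # Rescale every gradient by the same factor whenever the global norm
    # exceeds grad_bound (equivalent to global-norm clipping), and record the
    # per-variable norms of the clipped gradients for the caller.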
   1077     clipped_grads = []
   1078     clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
   1079     for g, v in grads_and_vars:
   1080       if g is not None:
   1081         if isinstance(g, tf_ops.IndexedSlices):
   1082           clipped = g.values / clipped_rate
   1083           norm_square = math_ops.reduce_sum(clipped * clipped)
   1084           clipped = tf_ops.IndexedSlices(clipped, g.indices)
   1085         else:
   1086           clipped = g / clipped_rate
   1087           norm_square = math_ops.reduce_sum(clipped * clipped)
   1088         all_grad_norms[v.name] = math_ops.sqrt(norm_square)
   1089         clipped_grads.append((clipped, v))
   1090 
   1091     train_op = opt.apply_gradients(clipped_grads, global_step)
   1092     return train_op, learning_rate, grad_norm, all_grad_norms
   1093 
   1094 
   1095 def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
   1096   """LSTM cell.
   1097 
   1098   Args:
   1099     x: tensors of size [num_children, hidden_size].
   1100     prev_c: tensors of size [num_children, hidden_size].
   1101     prev_h: same as prev_c.
   1102     w_lstm: .
   1103     forget_bias: .
   1104 
   1105   Returns:
   1106     next_c:
   1107     next_h:
   1108   """
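  # Fused gate computation: one matmul produces the input (i), forget (f),
  # output (o) and candidate (g) pre-activations, then
  #   next_c = sigmoid(i) * tanh(g) + sigmoid(f + forget_bias) * prev_c
  #   next_h = sigmoid(o) * tanh(next_c)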
   1109   ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm)
   1110   i, f, o, g = array_ops.split(ifog, 4, axis=1)
   1111   i = math_ops.sigmoid(i)
   1112   f = math_ops.sigmoid(f + forget_bias)
   1113   o = math_ops.sigmoid(o)
   1114   g = math_ops.tanh(g)
   1115   next_c = i * g + f * prev_c
   1116   next_h = o * math_ops.tanh(next_c)
   1117   return next_c, next_h
   1118