1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """HierarchicalController Class. 16 17 The HierarchicalController encompasses the entire lifecycle of training the 18 device placement policy, including generating op embeddings, getting groups for 19 each op, placing those groups and running the predicted placements. 20 21 Different assignment models can inherit from this class. 
22 """ 23 24 from __future__ import absolute_import 25 from __future__ import division 26 from __future__ import print_function 27 28 import math 29 import numpy as np 30 import six 31 from tensorflow.python.framework import constant_op 32 from tensorflow.python.framework import dtypes 33 from tensorflow.python.framework import errors 34 from tensorflow.python.framework import ops as tf_ops 35 from tensorflow.python.grappler.controller import Controller 36 from tensorflow.python.ops import array_ops 37 from tensorflow.python.ops import clip_ops 38 from tensorflow.python.ops import control_flow_ops 39 from tensorflow.python.ops import embedding_ops 40 from tensorflow.python.ops import init_ops 41 from tensorflow.python.ops import linalg_ops 42 from tensorflow.python.ops import math_ops 43 from tensorflow.python.ops import nn_ops 44 from tensorflow.python.ops import random_ops 45 from tensorflow.python.ops import state_ops 46 from tensorflow.python.ops import tensor_array_ops 47 from tensorflow.python.ops import variable_scope 48 from tensorflow.python.summary import summary 49 from tensorflow.python.training import adam 50 from tensorflow.python.training import gradient_descent 51 from tensorflow.python.training import learning_rate_decay 52 from tensorflow.python.training import training_util 53 54 55 class PlacerParams(object): 56 """Class to hold a set of placement parameters as name-value pairs. 57 58 A typical usage is as follows: 59 60 ```python 61 # Create a PlacerParams object specifying names and values of the model 62 # parameters: 63 params = PlacerParams(hidden_size=128, decay_steps=50) 64 65 # The parameters are available as attributes of the PlacerParams object: 66 hparams.hidden_size ==> 128 67 hparams.decay_steps ==> 50 68 ``` 69 70 """ 71 72 def __init__(self, **kwargs): 73 """Create an instance of `PlacerParams` from keyword arguments. 74 75 The keyword arguments specify name-values pairs for the parameters. 
76 The parameter types are inferred from the type of the values passed. 77 78 The parameter names are added as attributes of `PlacerParams` object, 79 and they can be accessed directly with the dot notation `params._name_`. 80 81 Example: 82 83 ```python 84 # Define 1 parameter: 'hidden_size' 85 params = PlacerParams(hidden_size=128) 86 params.hidden_size ==> 128 87 ``` 88 89 Args: 90 **kwargs: Key-value pairs where the key is the parameter name and 91 the value is the value for the parameter. 92 """ 93 for name, value in six.iteritems(kwargs): 94 self.add_param(name, value) 95 96 def add_param(self, name, value): 97 """Adds {name, value} pair to hyperparameters. 98 99 Args: 100 name: Name of the hyperparameter. 101 value: Value of the hyperparameter. Can be one of the following types: 102 int, float, string, int list, float list, or string list. 103 104 Raises: 105 ValueError: if one of the arguments is invalid. 106 """ 107 # Keys in kwargs are unique, but 'name' could be the name of a pre-existing 108 # attribute of this object. In that case we refuse to use it as a 109 # parameter name. 
110 if getattr(self, name, None) is not None: 111 raise ValueError("Parameter name is reserved: %s" % name) 112 setattr(self, name, value) 113 114 115 def hierarchical_controller_hparams(): 116 """Hyperparameters for hierarchical planner.""" 117 return PlacerParams( 118 hidden_size=512, 119 forget_bias_init=1.0, 120 temperature=1.0, 121 logits_std_noise=0.5, 122 stop_noise_step=750, 123 decay_steps=50, 124 max_num_outputs=5, 125 max_output_size=5, 126 tanh_constant=1.0, 127 adj_embed_dim=20, 128 grouping_hidden_size=64, 129 num_groups=None, 130 bi_lstm=True, 131 failing_signal=100, 132 stop_sampling=500, 133 start_with_failing_signal=True, 134 always_update_baseline=False, 135 bl_dec=0.9, 136 grad_bound=1.0, 137 lr=0.1, 138 lr_dec=0.95, 139 start_decay_step=400, 140 optimizer_type="adam", 141 stop_updating_after_steps=1000, 142 name="hierarchical_controller", 143 keep_prob=1.0, 144 reward_function="sqrt", 145 seed=1234, 146 # distributed training params 147 num_children=1) 148 149 150 class HierarchicalController(Controller): 151 """HierarchicalController class.""" 152 153 def __init__(self, hparams, item, cluster, controller_id=0): 154 """HierarchicalController class initializer. 155 156 Args: 157 hparams: All hyper-parameters. 158 item: The metagraph to place. 159 cluster: The cluster of hardware devices to optimize for. 160 controller_id: the id of the controller in a multi-controller setup. 
161 """ 162 super(HierarchicalController, self).__init__(item, cluster) 163 self.ctrl_id = controller_id 164 self.hparams = hparams 165 166 if self.hparams.num_groups is None: 167 self.num_groups = min(256, 20 * self.num_devices) 168 else: 169 self.num_groups = self.hparams.num_groups 170 171 # creates self.op_embeddings and self.type_dict 172 self.create_op_embeddings(verbose=False) 173 # TODO(azalia) clean up embedding/group_embedding_size names 174 self.group_emb_size = ( 175 2 * self.num_groups + len(self.type_dict) + 176 self.hparams.max_num_outputs * self.hparams.max_output_size) 177 self.embedding_size = self.group_emb_size 178 self.initializer = init_ops.glorot_uniform_initializer( 179 seed=self.hparams.seed) 180 181 with variable_scope.variable_scope( 182 self.hparams.name, 183 initializer=self.initializer, 184 reuse=variable_scope.AUTO_REUSE): 185 # define parameters of feedforward 186 variable_scope.get_variable("w_grouping_ff", [ 187 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + 188 self.hparams.adj_embed_dim, self.hparams.grouping_hidden_size 189 ]) 190 variable_scope.get_variable( 191 "w_grouping_softmax", 192 [self.hparams.grouping_hidden_size, self.num_groups]) 193 if self.hparams.bi_lstm: 194 variable_scope.get_variable("encoder_lstm_forward", [ 195 self.embedding_size + self.hparams.hidden_size / 2, 196 2 * self.hparams.hidden_size 197 ]) 198 variable_scope.get_variable("encoder_lstm_backward", [ 199 self.embedding_size + self.hparams.hidden_size / 2, 200 2 * self.hparams.hidden_size 201 ]) 202 variable_scope.get_variable( 203 "device_embeddings", [self.num_devices, self.hparams.hidden_size]) 204 variable_scope.get_variable( 205 "decoder_lstm", 206 [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) 207 variable_scope.get_variable( 208 "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) 209 variable_scope.get_variable("device_go_embedding", 210 [1, self.hparams.hidden_size]) 211 
variable_scope.get_variable( 212 "encoder_forget_bias", 213 shape=1, 214 dtype=dtypes.float32, 215 initializer=init_ops.constant_initializer( 216 self.hparams.forget_bias_init)) 217 variable_scope.get_variable( 218 "decoder_forget_bias", 219 shape=1, 220 dtype=dtypes.float32, 221 initializer=init_ops.constant_initializer( 222 self.hparams.forget_bias_init)) 223 variable_scope.get_variable( 224 "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) 225 variable_scope.get_variable( 226 "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) 227 variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) 228 229 else: 230 variable_scope.get_variable("encoder_lstm", [ 231 self.embedding_size + self.hparams.hidden_size, 232 4 * self.hparams.hidden_size 233 ]) 234 variable_scope.get_variable( 235 "device_embeddings", [self.num_devices, self.hparams.hidden_size]) 236 variable_scope.get_variable( 237 "decoder_lstm", 238 [2 * self.hparams.hidden_size, 4 * self.hparams.hidden_size]) 239 variable_scope.get_variable( 240 "device_softmax", [2 * self.hparams.hidden_size, self.num_devices]) 241 variable_scope.get_variable("device_go_embedding", 242 [1, self.hparams.hidden_size]) 243 variable_scope.get_variable( 244 "encoder_forget_bias", 245 shape=1, 246 dtype=dtypes.float32, 247 initializer=init_ops.constant_initializer( 248 self.hparams.forget_bias_init)) 249 variable_scope.get_variable( 250 "decoder_forget_bias", 251 shape=1, 252 dtype=dtypes.float32, 253 initializer=init_ops.constant_initializer( 254 self.hparams.forget_bias_init)) 255 variable_scope.get_variable( 256 "attn_w_1", [self.hparams.hidden_size, self.hparams.hidden_size]) 257 variable_scope.get_variable( 258 "attn_w_2", [self.hparams.hidden_size, self.hparams.hidden_size]) 259 variable_scope.get_variable("attn_v", [self.hparams.hidden_size, 1]) 260 seq2seq_input_layer = array_ops.placeholder_with_default( 261 array_ops.zeros([self.hparams.num_children, 262 self.num_groups, 263 
self.group_emb_size], 264 dtypes.float32), 265 shape=(self.hparams.num_children, self.num_groups, self.group_emb_size)) 266 self.seq2seq_input_layer = seq2seq_input_layer 267 268 def compute_reward(self, run_time): 269 if self.hparams.reward_function == "id": 270 reward = run_time 271 elif self.hparams.reward_function == "sqrt": 272 reward = math.sqrt(run_time) 273 elif self.hparams.reward_function == "log": 274 reward = math.log1p(run_time) 275 else: 276 raise NotImplementedError( 277 "Unrecognized reward function '%s', consider your " 278 "--reward_function flag value." % self.hparams.reward_function) 279 return reward 280 281 def build_controller(self): 282 """RL optimization interface. 283 284 Returns: 285 ops: A dictionary holding handles of the model used for training. 286 """ 287 288 self._global_step = training_util.get_or_create_global_step() 289 ops = {} 290 ops["loss"] = 0 291 292 failing_signal = self.compute_reward(self.hparams.failing_signal) 293 294 ctr = {} 295 296 with tf_ops.name_scope("controller_{}".format(self.ctrl_id)): 297 with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): 298 ctr["reward"] = {"value": [], "ph": [], "update": []} 299 ctr["ready"] = {"value": [], "ph": [], "update": []} 300 ctr["best_reward"] = {"value": [], "update": []} 301 for i in range(self.hparams.num_children): 302 reward_value = variable_scope.get_local_variable( 303 "reward_{}".format(i), 304 initializer=0.0, 305 dtype=dtypes.float32, 306 trainable=False) 307 reward_ph = array_ops.placeholder( 308 dtypes.float32, shape=(), name="reward_ph_{}".format(i)) 309 reward_update = state_ops.assign( 310 reward_value, reward_ph, use_locking=True) 311 ctr["reward"]["value"].append(reward_value) 312 ctr["reward"]["ph"].append(reward_ph) 313 ctr["reward"]["update"].append(reward_update) 314 best_reward = variable_scope.get_local_variable( 315 "best_reward_{}".format(i), 316 initializer=failing_signal, 317 dtype=dtypes.float32, 318 trainable=False) 319 
ctr["best_reward"]["value"].append(best_reward) 320 ctr["best_reward"]["update"].append( 321 state_ops.assign(best_reward, 322 math_ops.minimum(best_reward, reward_update))) 323 324 ready_value = variable_scope.get_local_variable( 325 "ready_{}".format(i), 326 initializer=True, 327 dtype=dtypes.bool, 328 trainable=False) 329 ready_ph = array_ops.placeholder( 330 dtypes.bool, shape=(), name="ready_ph_{}".format(i)) 331 ready_update = state_ops.assign( 332 ready_value, ready_ph, use_locking=True) 333 ctr["ready"]["value"].append(ready_value) 334 ctr["ready"]["ph"].append(ready_ph) 335 ctr["ready"]["update"].append(ready_update) 336 337 ctr["grouping_y_preds"], ctr["grouping_log_probs"] = self.get_groupings() 338 summary.histogram( 339 "grouping_actions", 340 array_ops.slice(ctr["grouping_y_preds"]["sample"], [0, 0], 341 [1, array_ops.shape(self.op_embeddings)[0]])) 342 343 with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): 344 ctr["baseline"] = variable_scope.get_local_variable( 345 "baseline", 346 initializer=failing_signal 347 if self.hparams.start_with_failing_signal else 0.0, 348 dtype=dtypes.float32, 349 trainable=False) 350 351 new_baseline = self.hparams.bl_dec * ctr["baseline"] + ( 352 1 - self.hparams.bl_dec) * math_ops.reduce_mean( 353 ctr["reward"]["value"]) 354 if not self.hparams.always_update_baseline: 355 baseline_mask = math_ops.less(ctr["reward"]["value"], failing_signal) 356 selected_reward = array_ops.boolean_mask(ctr["reward"]["value"], 357 baseline_mask) 358 selected_baseline = control_flow_ops.cond( 359 math_ops.reduce_any(baseline_mask), 360 lambda: math_ops.reduce_mean(selected_reward), 361 lambda: constant_op.constant(0, dtype=dtypes.float32)) 362 ctr["pos_reward"] = selected_baseline 363 pos_ = math_ops.less( 364 constant_op.constant(0, dtype=dtypes.float32), selected_baseline) 365 selected_baseline = self.hparams.bl_dec * ctr["baseline"] + ( 366 1 - self.hparams.bl_dec) * selected_baseline 367 selected_baseline = 
control_flow_ops.cond( 368 pos_, lambda: selected_baseline, lambda: ctr["baseline"]) 369 new_baseline = control_flow_ops.cond( 370 math_ops.less(self.global_step, 371 self.hparams.stop_updating_after_steps), 372 lambda: new_baseline, lambda: selected_baseline) 373 ctr["baseline_update"] = state_ops.assign( 374 ctr["baseline"], new_baseline, use_locking=True) 375 376 ctr["y_preds"], ctr["log_probs"] = self.get_placements() 377 summary.histogram("actions", ctr["y_preds"]["sample"]) 378 mask = math_ops.less(ctr["reward"]["value"], failing_signal) 379 ctr["loss"] = ctr["reward"]["value"] - ctr["baseline"] 380 ctr["loss"] *= ( 381 ctr["log_probs"]["sample"] + ctr["grouping_log_probs"]["sample"]) 382 383 selected_loss = array_ops.boolean_mask(ctr["loss"], mask) 384 selected_loss = control_flow_ops.cond( 385 math_ops.reduce_any(mask), 386 lambda: math_ops.reduce_mean(-selected_loss), 387 lambda: constant_op.constant(0, dtype=dtypes.float32)) 388 389 ctr["loss"] = control_flow_ops.cond( 390 math_ops.less(self.global_step, 391 self.hparams.stop_updating_after_steps), 392 lambda: math_ops.reduce_mean(-ctr["loss"]), lambda: selected_loss) 393 394 ctr["reward_s"] = math_ops.reduce_mean(ctr["reward"]["value"]) 395 summary.scalar("loss", ctr["loss"]) 396 summary.scalar("avg_reward", ctr["reward_s"]) 397 summary.scalar("best_reward_so_far", best_reward) 398 summary.scalar( 399 "advantage", 400 math_ops.reduce_mean(ctr["reward"]["value"] - ctr["baseline"])) 401 402 with variable_scope.variable_scope( 403 "optimizer", reuse=variable_scope.AUTO_REUSE): 404 (ctr["train_op"], ctr["lr"], ctr["grad_norm"], 405 ctr["grad_norms"]) = self._get_train_ops( 406 ctr["loss"], 407 tf_ops.get_collection(tf_ops.GraphKeys.TRAINABLE_VARIABLES), 408 self.global_step, 409 grad_bound=self.hparams.grad_bound, 410 lr_init=self.hparams.lr, 411 lr_dec=self.hparams.lr_dec, 412 start_decay_step=self.hparams.start_decay_step, 413 decay_steps=self.hparams.decay_steps, 414 
optimizer_type=self.hparams.optimizer_type) 415 416 summary.scalar("gradnorm", ctr["grad_norm"]) 417 summary.scalar("lr", ctr["lr"]) 418 ctr["summary"] = summary.merge_all() 419 ops["controller"] = ctr 420 421 self.ops = ops 422 return ops 423 424 @property 425 def global_step(self): 426 return self._global_step 427 428 def create_op_embeddings(self, verbose=False): 429 if verbose: 430 print("process input graph for op embeddings") 431 self.num_ops = len(self.important_ops) 432 # topological sort of important nodes 433 topo_order = [op.name for op in self.important_ops] 434 435 # create index to name for topologicaly sorted important nodes 436 name_to_topo_order_index = {} 437 for idx, x in enumerate(topo_order): 438 name_to_topo_order_index[x] = idx 439 self.name_to_topo_order_index = name_to_topo_order_index 440 441 # create adj matrix 442 adj_dict = {} 443 for idx, op in enumerate(self.important_ops): 444 for output_op in self.get_node_fanout(op): 445 output_op_name = output_op.name 446 if output_op_name in self.important_op_names: 447 if name_to_topo_order_index[op.name] not in adj_dict: 448 adj_dict[name_to_topo_order_index[op.name]] = [] 449 adj_dict[name_to_topo_order_index[op.name]].extend( 450 [name_to_topo_order_index[output_op_name], 1]) 451 if output_op_name not in adj_dict: 452 adj_dict[name_to_topo_order_index[output_op_name]] = [] 453 adj_dict[name_to_topo_order_index[output_op_name]].extend( 454 [name_to_topo_order_index[op.name], -1]) 455 456 # get op_type op_output_shape, and adj info 457 output_embed_dim = (self.hparams.max_num_outputs * 458 self.hparams.max_output_size) 459 460 # TODO(bsteiner): don't filter based on used ops so that we can generalize 461 # to models that use other types of ops. 
462 used_ops = set() 463 for node in self.important_ops: 464 op_type = str(node.op) 465 used_ops.add(op_type) 466 467 self.type_dict = {} 468 for op_type in self.cluster.ListAvailableOps(): 469 if op_type in used_ops: 470 self.type_dict[op_type] = len(self.type_dict) 471 472 op_types = np.zeros([self.num_ops], dtype=np.int32) 473 op_output_shapes = np.full( 474 [self.num_ops, output_embed_dim], -1.0, dtype=np.float32) 475 for idx, node in enumerate(self.important_ops): 476 op_types[idx] = self.type_dict[node.op] 477 # output shape 478 op_name = node.name 479 for i, output_prop in enumerate(self.node_properties[op_name]): 480 if output_prop.shape.__str__() == "<unknown>": 481 continue 482 shape = output_prop.shape 483 for j, dim in enumerate(shape.dim): 484 if dim.size >= 0: 485 if i * self.hparams.max_output_size + j >= output_embed_dim: 486 break 487 op_output_shapes[idx, 488 i * self.hparams.max_output_size + j] = dim.size 489 # adj for padding 490 op_adj = np.full( 491 [self.num_ops, self.hparams.adj_embed_dim], 0, dtype=np.float32) 492 for idx in adj_dict: 493 neighbors = adj_dict[int(idx)] 494 min_dim = min(self.hparams.adj_embed_dim, len(neighbors)) 495 padding_size = self.hparams.adj_embed_dim - min_dim 496 neighbors = neighbors[:min_dim] + [0] * padding_size 497 op_adj[int(idx)] = neighbors 498 499 # op_embedding starts here 500 op_embeddings = np.zeros( 501 [ 502 self.num_ops, 503 1 + self.hparams.max_num_outputs * self.hparams.max_output_size + 504 self.hparams.adj_embed_dim 505 ], 506 dtype=np.float32) 507 for idx, op_name in enumerate(topo_order): 508 op_embeddings[idx] = np.concatenate( 509 (np.array([op_types[idx]]), op_output_shapes[idx], op_adj[int(idx)])) 510 self.op_embeddings = constant_op.constant( 511 op_embeddings, dtype=dtypes.float32) 512 if verbose: 513 print("num_ops = {}".format(self.num_ops)) 514 print("num_types = {}".format(len(self.type_dict))) 515 516 def get_groupings(self, *args, **kwargs): 517 num_children = 
self.hparams.num_children 518 with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): 519 grouping_actions_cache = variable_scope.get_local_variable( 520 "grouping_actions_cache", 521 initializer=init_ops.zeros_initializer, 522 dtype=dtypes.int32, 523 shape=[num_children, self.num_ops], 524 trainable=False) 525 input_layer = self.op_embeddings 526 input_layer = array_ops.expand_dims(input_layer, 0) 527 feed_ff_input_layer = array_ops.tile(input_layer, [num_children, 1, 1]) 528 grouping_actions, grouping_log_probs = {}, {} 529 grouping_actions["sample"], grouping_log_probs[ 530 "sample"] = self.make_grouping_predictions(feed_ff_input_layer) 531 532 grouping_actions["sample"] = state_ops.assign(grouping_actions_cache, 533 grouping_actions["sample"]) 534 self.grouping_actions_cache = grouping_actions_cache 535 536 return grouping_actions, grouping_log_probs 537 538 def make_grouping_predictions(self, input_layer, reuse=None): 539 """model that predicts grouping (grouping_actions). 
540 541 Args: 542 input_layer: group_input_layer 543 reuse: reuse 544 545 Returns: 546 grouping_actions: actions 547 grouping_log_probs: log probabilities corresponding to actions 548 """ 549 with variable_scope.variable_scope(self.hparams.name, reuse=True): 550 # input_layer: tensor of size [1, num_ops, hidden_size] 551 w_grouping_ff = variable_scope.get_variable("w_grouping_ff") 552 w_grouping_softmax = variable_scope.get_variable("w_grouping_softmax") 553 554 batch_size = array_ops.shape(input_layer)[0] 555 embedding_dim = array_ops.shape(input_layer)[2] 556 557 reshaped = array_ops.reshape(input_layer, 558 [batch_size * self.num_ops, embedding_dim]) 559 ff_output = math_ops.matmul(reshaped, w_grouping_ff) 560 logits = math_ops.matmul(ff_output, w_grouping_softmax) 561 if self.hparams.logits_std_noise > 0: 562 num_in_logits = math_ops.cast( 563 array_ops.size(logits), dtype=dtypes.float32) 564 avg_norm = math_ops.divide( 565 linalg_ops.norm(logits), math_ops.sqrt(num_in_logits)) 566 logits_noise = random_ops.random_normal( 567 array_ops.shape(logits), 568 stddev=self.hparams.logits_std_noise * avg_norm) 569 logits = control_flow_ops.cond( 570 self.global_step > self.hparams.stop_noise_step, lambda: logits, 571 lambda: logits + logits_noise) 572 logits = array_ops.reshape(logits, 573 [batch_size * self.num_ops, self.num_groups]) 574 actions = random_ops.multinomial(logits, 1, seed=self.hparams.seed) 575 actions = math_ops.cast(actions, dtypes.int32) 576 actions = array_ops.reshape(actions, [batch_size, self.num_ops]) 577 action_label = array_ops.reshape(actions, [-1]) 578 log_probs = nn_ops.sparse_softmax_cross_entropy_with_logits( 579 logits=logits, labels=action_label) 580 log_probs = array_ops.reshape(log_probs, [batch_size, -1]) 581 log_probs = math_ops.reduce_sum(log_probs, 1) 582 grouping_actions = actions 583 grouping_log_probs = log_probs 584 return grouping_actions, grouping_log_probs 585 586 def create_group_embeddings(self, grouping_actions, 
verbose=False): 587 """Approximating the blocks of a TF graph from a graph_def. 588 589 Args: 590 grouping_actions: grouping predictions. 591 verbose: print stuffs. 592 593 Returns: 594 groups: list of groups. 595 """ 596 groups = [ 597 self._create_group_embeddings(grouping_actions, i, verbose) for 598 i in range(self.hparams.num_children) 599 ] 600 return np.stack(groups, axis=0) 601 602 def _create_group_embeddings(self, grouping_actions, child_id, verbose=False): 603 """Approximating the blocks of a TF graph from a graph_def for each child. 604 605 Args: 606 grouping_actions: grouping predictions. 607 child_id: child_id for the group. 608 verbose: print stuffs. 609 610 Returns: 611 groups: group embedding for the child_id. 612 """ 613 if verbose: 614 print("Processing input_graph") 615 616 # TODO(azalia): Build inter-adjacencies dag matrix. 617 # record dag_matrix 618 dag_matrix = np.zeros([self.num_groups, self.num_groups], dtype=np.float32) 619 for op in self.important_ops: 620 topo_op_index = self.name_to_topo_order_index[op.name] 621 group_index = grouping_actions[child_id][topo_op_index] 622 for output_op in self.get_node_fanout(op): 623 if output_op.name not in self.important_op_names: 624 continue 625 output_group_index = ( 626 grouping_actions[child_id][self.name_to_topo_order_index[ 627 output_op.name]]) 628 dag_matrix[group_index, output_group_index] += 1.0 629 num_connections = np.sum(dag_matrix) 630 num_intra_group_connections = dag_matrix.trace() 631 num_inter_group_connections = num_connections - num_intra_group_connections 632 if verbose: 633 print("grouping evaluation metric") 634 print(("num_connections={} num_intra_group_connections={} " 635 "num_inter_group_connections={}").format( 636 num_connections, num_intra_group_connections, 637 num_inter_group_connections)) 638 self.dag_matrix = dag_matrix 639 640 # output_shape 641 op_output_shapes = np.zeros( 642 [ 643 len(self.important_ops), 644 self.hparams.max_num_outputs * 
self.hparams.max_output_size 645 ], 646 dtype=np.float32) 647 648 for idx, op in enumerate(self.important_ops): 649 for i, output_properties in enumerate(self.node_properties[op.name]): 650 if output_properties.shape.__str__() == "<unknown>": 651 continue 652 if i > self.hparams.max_num_outputs: 653 break 654 shape = output_properties.shape 655 for j, dim in enumerate(shape.dim): 656 if dim.size > 0: 657 k = i * self.hparams.max_output_size + j 658 if k >= self.hparams.max_num_outputs * self.hparams.max_output_size: 659 break 660 op_output_shapes[idx, k] = dim.size 661 662 # group_embedding 663 group_embedding = np.zeros( 664 [ 665 self.num_groups, len(self.type_dict) + 666 self.hparams.max_num_outputs * self.hparams.max_output_size 667 ], 668 dtype=np.float32) 669 for op_index, op in enumerate(self.important_ops): 670 group_index = grouping_actions[child_id][ 671 self.name_to_topo_order_index[op.name]] 672 type_name = str(op.op) 673 type_index = self.type_dict[type_name] 674 group_embedding[group_index, type_index] += 1 675 group_embedding[group_index, :self.hparams.max_num_outputs * self.hparams. 
676 max_output_size] += ( 677 op_output_shapes[op_index]) 678 grouping_adjacencies = np.concatenate( 679 [dag_matrix, np.transpose(dag_matrix)], axis=1) 680 group_embedding = np.concatenate( 681 [grouping_adjacencies, group_embedding], axis=1) 682 group_normalizer = np.amax(group_embedding, axis=1, keepdims=True) 683 group_embedding /= (group_normalizer + 1.0) 684 if verbose: 685 print("Finished Processing Input Graph") 686 return group_embedding 687 688 def get_placements(self, *args, **kwargs): 689 num_children = self.hparams.num_children 690 with variable_scope.variable_scope("controller_{}".format(self.ctrl_id)): 691 actions_cache = variable_scope.get_local_variable( 692 "actions_cache", 693 initializer=init_ops.zeros_initializer, 694 dtype=dtypes.int32, 695 shape=[num_children, self.num_groups], 696 trainable=False) 697 698 x = self.seq2seq_input_layer 699 last_c, last_h, attn_mem = self.encode(x) 700 actions, log_probs = {}, {} 701 actions["sample"], log_probs["sample"] = ( 702 self.decode( 703 x, last_c, last_h, attn_mem, mode="sample")) 704 actions["target"], log_probs["target"] = ( 705 self.decode( 706 x, 707 last_c, 708 last_h, 709 attn_mem, 710 mode="target", 711 y=actions_cache)) 712 actions["greedy"], log_probs["greedy"] = ( 713 self.decode( 714 x, last_c, last_h, attn_mem, mode="greedy")) 715 actions["sample"] = control_flow_ops.cond( 716 self.global_step < self.hparams.stop_sampling, 717 lambda: state_ops.assign(actions_cache, actions["sample"]), 718 lambda: state_ops.assign(actions_cache, actions["target"])) 719 self.actions_cache = actions_cache 720 721 return actions, log_probs 722 723 def encode(self, x): 724 """Encoder using LSTM. 725 726 Args: 727 x: tensor of size [num_children, num_groups, embedding_size] 728 729 Returns: 730 last_c, last_h: tensors of size [num_children, hidden_size], the final 731 LSTM states 732 attn_mem: tensor of size [num_children, num_groups, hidden_size], the 733 attention 734 memory, i.e. 
concatenation of all hidden states, linearly transformed by 735 an attention matrix attn_w_1 736 """ 737 if self.hparams.bi_lstm: 738 with variable_scope.variable_scope(self.hparams.name, reuse=True): 739 w_lstm_forward = variable_scope.get_variable("encoder_lstm_forward") 740 w_lstm_backward = variable_scope.get_variable("encoder_lstm_backward") 741 forget_bias = variable_scope.get_variable("encoder_forget_bias") 742 attn_w_1 = variable_scope.get_variable("attn_w_1") 743 else: 744 with variable_scope.variable_scope(self.hparams.name, reuse=True): 745 w_lstm = variable_scope.get_variable("encoder_lstm") 746 forget_bias = variable_scope.get_variable("encoder_forget_bias") 747 attn_w_1 = variable_scope.get_variable("attn_w_1") 748 749 embedding_size = array_ops.shape(x)[2] 750 751 signals = array_ops.split(x, self.num_groups, axis=1) 752 for i in range(len(signals)): 753 signals[i] = array_ops.reshape( 754 signals[i], [self.hparams.num_children, embedding_size]) 755 756 if self.hparams.bi_lstm: 757 758 def body(i, prev_c_forward, prev_h_forward, prev_c_backward, 759 prev_h_backward): 760 """while loop for LSTM.""" 761 signal_forward = signals[i] 762 next_c_forward, next_h_forward = lstm(signal_forward, prev_c_forward, 763 prev_h_forward, w_lstm_forward, 764 forget_bias) 765 766 signal_backward = signals[self.num_groups - 1 - i] 767 next_c_backward, next_h_backward = lstm( 768 signal_backward, prev_c_backward, prev_h_backward, w_lstm_backward, 769 forget_bias) 770 771 next_h = array_ops.concat([next_h_forward, next_h_backward], axis=1) 772 all_h.append(next_h) 773 774 return (next_c_forward, next_h_forward, next_c_backward, 775 next_h_backward) 776 777 c_forward = array_ops.zeros( 778 [self.hparams.num_children, self.hparams.hidden_size / 2], 779 dtype=dtypes.float32) 780 h_forward = array_ops.zeros( 781 [self.hparams.num_children, self.hparams.hidden_size / 2], 782 dtype=dtypes.float32) 783 784 c_backward = array_ops.zeros( 785 [self.hparams.num_children, 
          self.hparams.hidden_size / 2],
          dtype=dtypes.float32)
      # NOTE(review): with `from __future__ import division`,
      # hidden_size / 2 is a float dimension; presumably converted to int
      # downstream -- consider hidden_size // 2. TODO confirm.
      h_backward = array_ops.zeros(
          [self.hparams.num_children, self.hparams.hidden_size / 2],
          dtype=dtypes.float32)
      all_h = []

      # Python-level unroll over groups: `body` advances both directions one
      # step per call and (as a side effect) appends hidden states to all_h.
      for i in range(0, self.num_groups):
        c_forward, h_forward, c_backward, h_backward = body(
            i, c_forward, h_forward, c_backward, h_backward)

      # Final encoder state is the concatenation of both directions.
      last_c = array_ops.concat([c_forward, c_backward], axis=1)
      last_h = array_ops.concat([h_forward, h_backward], axis=1)
      attn_mem = array_ops.stack(all_h)

    else:

      # Unidirectional LSTM over the group embeddings.
      def body(i, prev_c, prev_h):
        # One LSTM step for group i; side effect: appends next_h to all_h.
        signal = signals[i]
        next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
        all_h.append(next_h)
        return next_c, next_h

      c = array_ops.zeros(
          [self.hparams.num_children, self.hparams.hidden_size],
          dtype=dtypes.float32)
      h = array_ops.zeros(
          [self.hparams.num_children, self.hparams.hidden_size],
          dtype=dtypes.float32)
      all_h = []

      # Python-level unroll (num_groups is a static Python int here).
      for i in range(0, self.num_groups):
        c, h = body(i, c, h)

      last_c = c
      last_h = h
      attn_mem = array_ops.stack(all_h)

    # all_h stacks to [num_groups, num_children, hidden_size]; move the
    # children (batch) dimension first.
    attn_mem = array_ops.transpose(attn_mem, [1, 0, 2])
    # Project the attention memory through attn_w_1, flattening batch and
    # group dimensions into one matmul.
    attn_mem = array_ops.reshape(
        attn_mem,
        [self.hparams.num_children * self.num_groups, self.hparams.hidden_size])
    attn_mem = math_ops.matmul(attn_mem, attn_w_1)
    attn_mem = array_ops.reshape(
        attn_mem,
        [self.hparams.num_children, self.num_groups, self.hparams.hidden_size])

    return last_c, last_h, attn_mem

  def decode(self,
             x,
             last_c,
             last_h,
             attn_mem,
             mode="target",
             y=None):
    """Decoder using LSTM.

    Args:
      x: tensor of size [num_children, num_groups, embedding_size]. Not read
        by the loop below; kept for interface compatibility.
      last_c: tensor of size [num_children, hidden_size], the final LSTM states
        computed by self.encoder.
      last_h: same as last_c.
      attn_mem: tensor of size [num_children, num_groups, hidden_size].
      mode: "target" (teacher-force from `y`), "sample" (multinomial), or
        "greedy" (argmax); anything else raises NotImplementedError.
      y: tensor of size [num_children, num_groups], the device placements.

    Returns:
      actions: tensor of size [num_children, num_groups], the placements of
        devices
      log_probs: tensor of size [num_children], the summed softmax
        cross-entropy of the chosen actions.
    """
    # All decoder parameters were created elsewhere; fetch them by name.
    with variable_scope.variable_scope(self.hparams.name, reuse=True):
      w_lstm = variable_scope.get_variable("decoder_lstm")
      forget_bias = variable_scope.get_variable("decoder_forget_bias")
      device_embeddings = variable_scope.get_variable("device_embeddings")
      device_softmax = variable_scope.get_variable("device_softmax")
      device_go_embedding = variable_scope.get_variable("device_go_embedding")
      attn_w_2 = variable_scope.get_variable("attn_w_2")
      attn_v = variable_scope.get_variable("attn_v")

    # One int32 action per group; read back inside the loop, hence
    # clear_after_read=False.
    actions = tensor_array_ops.TensorArray(
        dtypes.int32,
        size=self.num_groups,
        infer_shape=False,
        clear_after_read=False)

    # pylint: disable=unused-argument
    def condition(i, *args):
      return math_ops.less(i, self.num_groups)

    # pylint: disable=missing-docstring
    def body(i, prev_c, prev_h, actions, log_probs):
      # Input to the LSTM: the go-embedding on step 0, otherwise the
      # embedding of the previously chosen device.
      # pylint: disable=g-long-lambda
      signal = control_flow_ops.cond(
          math_ops.equal(i, 0),
          lambda: array_ops.tile(device_go_embedding,
                                 [self.hparams.num_children, 1]),
          lambda: embedding_ops.embedding_lookup(device_embeddings,
                                                 actions.read(i - 1))
      )
      if self.hparams.keep_prob is not None:
        signal = nn_ops.dropout(signal, self.hparams.keep_prob)
      next_c, next_h = lstm(signal, prev_c, prev_h, w_lstm, forget_bias)
      # Additive (Bahdanau-style) attention over the encoder memory:
      # score = v^T tanh(W2 h + mem), softmaxed over groups.
      query = math_ops.matmul(next_h, attn_w_2)
      query = array_ops.reshape(
          query, [self.hparams.num_children, 1, self.hparams.hidden_size])
      query = math_ops.tanh(query + attn_mem)
      query = array_ops.reshape(query, [
          self.hparams.num_children * self.num_groups, self.hparams.hidden_size
      ])
      query = math_ops.matmul(query, attn_v)
      query = array_ops.reshape(query,
                                [self.hparams.num_children, self.num_groups])
      query = nn_ops.softmax(query)
      query = array_ops.reshape(query,
                                [self.hparams.num_children, self.num_groups, 1])
      # Context vector: attention-weighted sum of the memory.
      query = math_ops.reduce_sum(attn_mem * query, axis=1)
      query = array_ops.concat([next_h, query], axis=1)
      logits = math_ops.matmul(query, device_softmax)
      logits /= self.hparams.temperature
      if self.hparams.tanh_constant > 0:
        logits = math_ops.tanh(logits) * self.hparams.tanh_constant
      if self.hparams.logits_std_noise > 0:
        # Exploration noise scaled by the average logit norm, disabled once
        # global_step passes stop_noise_step.
        num_in_logits = math_ops.cast(
            array_ops.size(logits), dtype=dtypes.float32)
        avg_norm = math_ops.divide(
            linalg_ops.norm(logits), math_ops.sqrt(num_in_logits))
        logits_noise = random_ops.random_normal(
            array_ops.shape(logits),
            stddev=self.hparams.logits_std_noise * avg_norm)
        logits = control_flow_ops.cond(
            self.global_step > self.hparams.stop_noise_step, lambda: logits,
            lambda: logits + logits_noise)

      if mode == "sample":
        next_y = random_ops.multinomial(logits, 1, seed=self.hparams.seed)
      elif mode == "greedy":
        next_y = math_ops.argmax(logits, 1)
      elif mode == "target":
        next_y = array_ops.slice(y, [0, i], [-1, 1])
      else:
        raise NotImplementedError
      next_y = math_ops.cast(next_y, dtypes.int32)
      next_y = array_ops.reshape(next_y, [self.hparams.num_children])
      actions = actions.write(i, next_y)
      # Accumulate negative log-likelihood of the chosen action.
      log_probs += nn_ops.sparse_softmax_cross_entropy_with_logits(
          logits=logits, labels=next_y)
      return i + 1, next_c, next_h, actions, log_probs

    loop_vars = [
        constant_op.constant(0, dtype=dtypes.int32), last_c, last_h, actions,
        array_ops.zeros([self.hparams.num_children], dtype=dtypes.float32)
    ]
    loop_outputs = control_flow_ops.while_loop(condition, body, loop_vars)

    last_c = loop_outputs[-4]
    last_h = loop_outputs[-3]
    # [num_groups, num_children] -> [num_children, num_groups].
    actions = loop_outputs[-2].stack()
    actions = array_ops.transpose(actions, [1, 0])
    log_probs = loop_outputs[-1]
    return actions, log_probs

  def eval_placement(self,
                     sess,
                     child_id=0,
                     verbose=False):
    """Applies the cached placement for `child_id` and measures its runtime.

    Reads the grouping and device decisions from the controller caches,
    assigns each important op to its predicted device, then times the
    annotated item on the cluster.

    Args:
      sess: TF session used to fetch the cached actions.
      child_id: which child's (sampled placement's) cache entries to use.
      verbose: if True, log the assignments every 100 global steps.

    Returns:
      run_time: the measured cost, or hparams.failing_signal when the
        placement runs out of memory.
    """
    grouping_actions, actions = sess.run([
        self.grouping_actions_cache,
        self.actions_cache
    ])
    grouping_actions = grouping_actions[child_id]
    actions = actions[child_id]
    if verbose:
      global_step = sess.run(self.global_step)
      if global_step % 100 == 0:
        log_string = "op group assignments: "
        for a in grouping_actions:
          log_string += "{} ".format(a)
        # Trailing-space trim via slice.
        print(log_string[:-1])
        log_string = "group device assignments: "
        for a in actions:
          log_string += "{} ".format(a)
        print(log_string[:-1])

    # Annotate each important op with its predicted device:
    # op -> group (grouping_actions) -> device (actions).
    for op in self.important_ops:
      topo_order_index = self.name_to_topo_order_index[op.name]
      group_index = grouping_actions[topo_order_index]
      op.device = self.devices[actions[group_index]].name
    try:
      _, run_time, _ = self.cluster.MeasureCosts(self.item)
    except errors.ResourceExhaustedError:
      # OOM placements receive a large penalty signal instead of aborting.
      run_time = self.hparams.failing_signal
    return run_time

  def update_reward(self,
                    sess,
                    run_time,
                    child_id=0,
                    verbose=False):
    """Feeds the measured run_time into the reward/best-reward updates.

    Args:
      sess: TF session.
      run_time: measured runtime for this child's placement.
      child_id: which child's reward accumulators to update.
      verbose: if True, print run_time, reward and best_reward.

    Returns:
      True when this reward matches (within tolerance) the best reward,
      i.e. the best-reward variable was effectively updated by this sample.
    """
    reward = self.compute_reward(run_time)
    controller_ops = self.ops["controller"]
    _, best_reward = sess.run(
        [
            controller_ops["reward"]["update"][child_id],
            controller_ops["best_reward"]["update"][child_id]
        ],
        feed_dict={
            controller_ops["reward"]["ph"][child_id]: reward,
        })
    if verbose:
      print(("run_time={:<.5f} reward={:<.5f} "
             "best_reward={:<.5f}").format(run_time, reward, best_reward))

    # Reward is a double, best_reward a float: allow for some slack in the
    # comparison.
    updated = abs(best_reward - reward) < 1e-6
    return updated

  def generate_grouping(self, sess):
    """Samples an op-to-group assignment from the grouper network."""
    controller_ops = self.ops["controller"]
    grouping_actions = sess.run(controller_ops["grouping_y_preds"]["sample"])
    return grouping_actions

  def generate_placement(self, grouping, sess):
    """Samples device placements for `grouping`.

    The sampled placements are written into the controller's caches as a
    side effect of running the op; nothing is returned.
    """
    controller_ops = self.ops["controller"]
    feed_seq2seq_input_dict = {}
    feed_seq2seq_input_dict[self.seq2seq_input_layer] = grouping
    sess.run(
        controller_ops["y_preds"]["sample"], feed_dict=feed_seq2seq_input_dict)

  def process_reward(self, sess):
    """Runs one policy-gradient train step, then updates the baseline."""
    controller_ops = self.ops["controller"]
    # loss/lr/grad_norm(s) are fetched alongside train_op so they are
    # computed in the same run; their values are not used here.
    run_ops = [
        controller_ops["loss"], controller_ops["lr"],
        controller_ops["grad_norm"], controller_ops["grad_norms"],
        controller_ops["train_op"]
    ]
    sess.run(run_ops)
    sess.run(controller_ops["baseline_update"])

  def _get_train_ops(self,
                     loss,
                     tf_variables,
                     global_step,
                     grad_bound=1.25,
                     lr_init=1e-3,
                     lr_dec=0.9,
                     start_decay_step=10000,
                     decay_steps=100,
                     optimizer_type="adam"):
    """Loss optimizer.

    Args:
      loss: scalar tf tensor
      tf_variables: list of training variables, typically
        tf.trainable_variables()
      global_step: global_step
      grad_bound: max gradient norm
      lr_init: initial learning rate
      lr_dec: learning rate decay coefficient
      start_decay_step: start decaying learning rate after this many steps
      decay_steps: apply decay rate factor at this step intervals
      optimizer_type: optimizer type should be either adam or sgd

    Returns:
      train_op: training op
      learning_rate: scalar learning rate tensor
      grad_norm: l2 norm of the gradient vector
      all_grad_norms: l2 norm of each component
    """
    # Step counter used by the decay schedule, shifted so decay starts at 0
    # when global_step == start_decay_step.
    lr_gstep = global_step - start_decay_step

    def f1():
      # Constant learning rate before decay kicks in.
      return constant_op.constant(lr_init)

    def f2():
      # Staircase (True) exponential decay after start_decay_step.
      return learning_rate_decay.exponential_decay(lr_init, lr_gstep,
                                                   decay_steps, lr_dec, True)

    learning_rate = control_flow_ops.cond(
        math_ops.less(global_step, start_decay_step),
        f1,
        f2,
        name="learning_rate")

    if optimizer_type == "adam":
      opt = adam.AdamOptimizer(learning_rate)
    elif optimizer_type == "sgd":
      opt = gradient_descent.GradientDescentOptimizer(learning_rate)
    # NOTE(review): any other optimizer_type leaves `opt` unbound and the
    # next line raises NameError; consider validating the argument.
    grads_and_vars = opt.compute_gradients(loss, tf_variables)
    grad_norm = clip_ops.global_norm([g for g, v in grads_and_vars])
    all_grad_norms = {}
    clipped_grads = []
    # Uniform rescale so the global norm does not exceed grad_bound
    # (factor is 1.0, i.e. a no-op, when already within the bound).
    clipped_rate = math_ops.maximum(grad_norm / grad_bound, 1.0)
    for g, v in grads_and_vars:
      if g is not None:
        if isinstance(g, tf_ops.IndexedSlices):
          # Sparse gradient: scale the values, keep the indices.
          clipped = g.values / clipped_rate
          norm_square = math_ops.reduce_sum(clipped * clipped)
          clipped = tf_ops.IndexedSlices(clipped, g.indices)
        else:
          clipped = g / clipped_rate
          norm_square = math_ops.reduce_sum(clipped * clipped)
        # Per-variable post-clip gradient norm, keyed by variable name.
        all_grad_norms[v.name] = math_ops.sqrt(norm_square)
        clipped_grads.append((clipped, v))

    train_op = opt.apply_gradients(clipped_grads,
                                   global_step)
    return train_op, learning_rate, grad_norm, all_grad_norms


def lstm(x, prev_c, prev_h, w_lstm, forget_bias):
  """LSTM cell.

  Args:
    x: tensors of size [num_children, hidden_size].
    prev_c: tensors of size [num_children, hidden_size].
    prev_h: same as prev_c.
    w_lstm: combined gate weight matrix applied to [x, prev_h]; presumably of
      size [2 * hidden_size, 4 * hidden_size] given the concat/split below --
      confirm against where the variable is created.
    forget_bias: scalar (or broadcastable) bias added to the forget gate
      pre-activation before the sigmoid.

  Returns:
    next_c: next cell state, i * g + f * prev_c.
    next_h: next hidden state, o * tanh(next_c).
  """
  # One matmul computes all four gate pre-activations; split order i, f, o, g.
  ifog = math_ops.matmul(array_ops.concat([x, prev_h], axis=1), w_lstm)
  i, f, o, g = array_ops.split(ifog, 4, axis=1)
  i = math_ops.sigmoid(i)
  f = math_ops.sigmoid(f + forget_bias)
  o = math_ops.sigmoid(o)
  g = math_ops.tanh(g)
  next_c = i * g + f * prev_c
  next_h = o * math_ops.tanh(next_c)
  return next_c, next_h