# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TF-GAN project provides a lightweight GAN training/testing framework.

This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.

TF-GAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops

The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
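
Example (a minimal sketch of the four steps; `my_generator`,
`my_discriminator`, `real_images`, `batch_size`, and `noise_dims` are
hypothetical user-defined names, not part of this file):

  import tensorflow as tf
  tfgan = tf.contrib.gan

  # 1) Create a model. `real_images` is a hypothetical batch of real data.
  model = tfgan.gan_model(
      generator_fn=my_generator,
      discriminator_fn=my_discriminator,
      real_data=real_images,
      generator_inputs=tf.random_normal([batch_size, noise_dims]))
  # 2) Add a loss.
  loss = tfgan.gan_loss(model, gradient_penalty_weight=1.0)
  # 3) Create train ops.
  train_ops = tfgan.gan_train_ops(
      model, loss,
      generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
      discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))
  # 4) Run the train ops.
  tfgan.gan_train(train_ops, logdir='/tmp/tfgan_logdir')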
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util

__all__ = [
    'gan_model',
    'infogan_model',
    'acgan_model',
    'cyclegan_model',
    'stargan_model',
    'gan_loss',
    'cyclegan_loss',
    'stargan_loss',
    'gan_train_ops',
    'gan_train',
    'get_sequential_train_hooks',
    'get_joint_train_hooks',
    'get_sequential_train_steps',
    'RunTrainOpsHook',
]


def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the
      conditional GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that
      are the same shape as the real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
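
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and hypothetical
    user-defined `my_generator` and `my_discriminator` functions:

      noise = tf.random_normal([batch_size, noise_dims])
      model = tfgan.gan_model(
          generator_fn=my_generator,
          discriminator_fn=my_discriminator,
          real_data=real_images,
          generator_inputs=noise)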
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = _convert_tensor_or_l_or_d(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn)


def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of TensorFlow distributions, where the ith distribution represents the
      predicted noise distribution of the ith structured noise input.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator. These
      tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator. These
      tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
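
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and hypothetical
    user-defined `my_infogan_generator` and `my_infogan_discriminator` (the
    discriminator must return a (logits, distribution_list) 2-tuple):

      unstructured_noise = tf.random_normal([batch_size, noise_dims])
      cat_idx = tf.random_uniform([batch_size], maxval=10, dtype=tf.int32)
      cat_code = tf.one_hot(cat_idx, 10)  # structured categorical code
      model = tfgan.infogan_model(
          generator_fn=my_infogan_generator,
          discriminator_fn=my_infogan_discriminator,
          real_data=real_images,
          unstructured_generator_inputs=[unstructured_noise],
          structured_generator_inputs=[cat_code])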
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)


def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel with all the pieces needed for ACGAN training.

  The `acgan_model` is the same as the `gan_model`, with the only difference
  being that the discriminator additionally outputs logits classifying the
  input (whether real or generated). Therefore, an explicit field holding
  `one_hot_labels` is necessary, as well as a `discriminator_fn` that outputs
  a 2-tuple holding the logits for real/fake and for classification.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the
      conditional GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      `acgan_loss`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that
      are the same shape as the real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
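
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and hypothetical
    user-defined `my_acgan_generator` and `my_acgan_discriminator` (the
    discriminator must return a (real/fake logits, class logits) 2-tuple):

      noise = tf.random_normal([batch_size, noise_dims])
      model = tfgan.acgan_model(
          generator_fn=my_acgan_generator,
          discriminator_fn=my_acgan_discriminator,
          real_data=real_images,
          generator_inputs=noise,
          one_hot_labels=tf.one_hot(labels, num_classes))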
  """
  # Create models
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    with ops.name_scope(dis_scope.name + '/generated/'):
      (discriminator_gen_outputs, discriminator_gen_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(generated_data, generator_inputs))
  with variable_scope.variable_scope(dis_scope, reuse=True):
    with ops.name_scope(dis_scope.name + '/real/'):
      real_data = ops.convert_to_tensor(real_data)
      (discriminator_real_outputs, discriminator_real_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn, one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)


def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True):
  """Returns CycleGAN model outputs and variables.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
    data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created. Defaults to 'Generator'.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created. Defaults to
      'Discriminator'.
    model_x2y_scope: Optional variable scope for model x2y variables. Defaults
      to 'ModelX2Y'.
    model_y2x_scope: Optional variable scope for model y2x variables. Defaults
      to 'ModelY2X'.
    check_shapes: If `True`, check that the generator produces Tensors that
      are the same shape as `data_x` (`data_y`). Otherwise, skip this check.

  Returns:
    A `CycleGANModel` namedtuple.

  Raises:
    ValueError: If `check_shapes` is True and `data_x` or the generator output
      does not have the same shape as `data_y`.
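
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and hypothetical
    user-defined `my_image_translator` and `my_patch_discriminator`:

      model = tfgan.cyclegan_model(
          generator_fn=my_image_translator,
          discriminator_fn=my_patch_discriminator,
          data_x=horse_images,
          data_y=zebra_images)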
  """

  # Create models.
  def _define_partial_model(input_data, output_data):
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with variable_scope.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with variable_scope.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)


def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns StarGAN model outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns `generated_data` as the transformed version of `input` based
      on the `target`. `input` has shape (n, h, w, c), `targets` has shape (n,
      num_domains), and `generated_data` has the same shape as `input`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n,) and `domain_prediction` has shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images.
    input_data_domain_label: Tensor or a list of tensors of shape (batch_size,
      num_domains) representing the domain label associated with the real
      images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple containing the tensors needed to compute the
    loss.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or
      fully defined in every dimension.
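
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and hypothetical
    user-defined `my_stargan_generator` and `my_stargan_discriminator` with
    the signatures described above:

      model = tfgan.stargan_model(
          generator_fn=my_stargan_generator,
          discriminator_fn=my_stargan_discriminator,
          input_data=face_images,
          input_data_domain_label=tf.one_hot(domain_idx, num_domains))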
  """

  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Convert a list of tensors to a single tensor if applicable.
  if isinstance(input_data, (list, tuple)):
    input_data = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data], 0)
  if isinstance(input_data_domain_label, (list, tuple)):
    input_data_domain_label = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)

  # Get batch_size, num_domains from the labels.
  input_data_domain_label.shape.assert_has_rank(2)
  input_data_domain_label.shape.assert_is_fully_defined()
  batch_size, num_domains = input_data_domain_label.shape.as_list()

  # Transform input_data to random target domains.
  with variable_scope.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with variable_scope.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with variable_scope.variable_scope(
      discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with variable_scope.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from the neural networks.
  generator_variables = variables_lib.get_trainable_variables(generator_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      discriminator_scope)

  # Create the StarGANModel namedtuple.
  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=generated_data_domain_target,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=disc_input_data_source_pred,
      discriminator_generated_data_source_predication=disc_gen_data_source_pred,
      discriminator_input_data_domain_predication=disc_input_data_domain_pred,
      discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=discriminator_variables,
      discriminator_scope=discriminator_scope,
      discriminator_fn=discriminator_fn)


def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
  if isinstance(aux_loss_weight, ops.Tensor):
    aux_loss_weight.shape.assert_is_compatible_with([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
      aux_loss_weight = array_ops.identity(aux_loss_weight)
  elif aux_loss_weight is not None and aux_loss_weight < 0:
    raise ValueError('`%s` must be non-negative. Instead, was %s' %
                     (name, aux_loss_weight))
  return aux_loss_weight


def _use_aux_loss(aux_loss_weight):
  if aux_loss_weight is not None:
    if not isinstance(aux_loss_weight, ops.Tensor):
      return aux_loss_weight > 0
    else:
      return True
  else:
    return False


def _tensor_pool_adjusted_model(model, tensor_pool_fn):
  """Adjusts model using `tensor_pool_fn`.

  Args:
    model: A GANModel tuple.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns a previously stored
      (generated_data, generator_inputs) with some probability. For example
      `tfgan.features.tensor_pool`.

  Returns:
    A new GANModel tuple where discriminator outputs are adjusted by taking
    pooled generator outputs as inputs. Returns the original model if
    `tensor_pool_fn` is None.

  Raises:
    ValueError: If tensor pool does not support the `model`.
  """
  if isinstance(model, namedtuples.GANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
                                               pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=dis_gen_outputs)
  elif isinstance(model, namedtuples.ACGANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        discriminator_gen_classification_logits=
        pooled_discriminator_gen_classification_logits)
  elif isinstance(model, namedtuples.InfoGANModel):
    pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
        tensor_pool_fn((model.generator_inputs, model.generated_data,
                        model.structured_generator_inputs)))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_predicted_distributions) = model.discriminator_and_aux_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        structured_generator_inputs=pooled_structured_input,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        predicted_distributions=pooled_predicted_distributions)
  else:
    raise ValueError('Tensor pool does not support `model`: %s.' % type(model))


def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not None, add a classification loss as in
      https://arxiv.org/abs/1610.09585.
    aux_cond_discriminator_weight: If not None, add a classification loss as
      in https://arxiv.org/abs/1610.09585.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool and returns previously stored
      (generated_data, generator_inputs). For example
      `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool).
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
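
  Example:
    A minimal sketch of the improved-Wasserstein setup, assuming
    `tfgan = tf.contrib.gan` and a `model` built by `tfgan.gan_model`:

      loss = tfgan.gan_loss(
          model,
          generator_loss_fn=tfgan.losses.wasserstein_generator_loss,
          discriminator_loss_fn=tfgan.losses.wasserstein_discriminator_loss,
          gradient_penalty_weight=10.0)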
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty.
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model.
  if tensor_pool_fn:
    pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
  else:
    pooled_model = model

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    if tensor_pool_fn is None:
      dis_info_loss = gen_info_loss
    else:
      dis_info_loss = tfgan_losses.mutual_information_penalty(
          pooled_model, add_summaries=add_summaries)
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        pooled_model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)


def cyclegan_loss(
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss,
    # Auxiliary losses.
    cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    # Options.
    **kwargs):
  """Returns the losses for a `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model: A `CycleGANModel` namedtuple.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` namedtuple.
    cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
      `CycleGANModel` namedtuple.
    cycle_consistency_loss_weight: A non-negative Python number or a scalar
      `Tensor` indicating how much to weigh the cycle consistency loss.
    **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
      for each partial model of `model`.

  Returns:
    A `CycleGANLoss` namedtuple.

  Raises:
    ValueError: If `model` is not a `CycleGANModel` namedtuple.
  """
  # Sanity checks.
  if not isinstance(model, namedtuples.CycleGANModel):
    raise ValueError(
        '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))

  # Defines cycle consistency loss.
  cycle_consistency_loss = cycle_consistency_loss_fn(
      model, add_summaries=kwargs.get('add_summaries', True))
  cycle_consistency_loss_weight = _validate_aux_loss_weight(
      cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
  aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

  # Defines losses for each partial model.
  def _partial_loss(partial_model):
    partial_loss = gan_loss(
        partial_model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
        **kwargs)
    return partial_loss._replace(
        generator_loss=partial_loss.generator_loss + aux_loss)

  with ops.name_scope('cyclegan_loss_x2y'):
    loss_x2y = _partial_loss(model.model_x2y)
  with ops.name_scope('cyclegan_loss_y2x'):
    loss_y2x = _partial_loss(model.model_y2x)

  return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)


def stargan_loss(
    model,
    generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
        tfgan_losses_impl.wasserstein_generator_loss),
    discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
        tfgan_losses_impl.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """StarGAN Loss.

  Args:
    model: A `StarGANModel` namedtuple, as output by the `stargan_model`
      function call.
    generator_loss_fn: The loss function on the generator. Takes a
      `StarGANModel` namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `StarGANModel` namedtuple.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
      per the original paper, https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to the
      L1 norm; the function must conform to the `tf.losses` API.
    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability
      to classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss; the function must conform to the `tf.losses` API.
    classification_loss_weight: (float) Classification loss weight. Defaults
      to 1.0.
    classification_one_hot: (bool) Whether the label is a one-hot
      representation. Defaults to True. If False, `classification_loss_fn`
      needs to be a sigmoid cross entropy loss instead.
    add_summaries: (bool) Whether to add the losses to the summary.

  Returns:
    A GANLoss namedtuple of (generator loss, discriminator loss).

  Raises:
    ValueError: If the input `StarGANModel.input_data_domain_label` does not
      have rank 2, or dimension 2 is not defined.
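
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and a `model` built by
    `tfgan.stargan_model` (the defaults follow the original paper):

      loss = tfgan.stargan_loss(model)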
  """

  def _classification_loss_helper(true_labels, predict_logits, scope_name):
    """Classification Loss Function Helper.

    Args:
      true_labels: Tensor of shape [batch_size, num_domains] representing the
        label, where each row is a one-hot vector.
      predict_logits: Tensor of shape [batch_size, num_domains] representing
        the predicted label logits, which are the UNSCALED output of the
        network.
      scope_name: (string) Name scope of the loss component.

    Returns:
      A single scalar tensor representing the classification loss.
    """
    with ops.name_scope(scope_name, values=(true_labels, predict_logits)):
      loss = classification_loss_fn(
          onehot_labels=true_labels, logits=predict_logits)

      if not classification_one_hot:
        loss = math_ops.reduce_sum(loss, axis=1)
        loss = math_ops.reduce_mean(loss)

      if add_summaries:
        summary.scalar(scope_name, loss)

      return loss

  # Check input shape.
  model.input_data_domain_label.shape.assert_has_rank(2)
  model.input_data_domain_label.shape[1:].assert_is_fully_defined()

  # Adversarial Loss.
  generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
  discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Gradient Penalty.
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
        tfgan_losses_impl.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

  # Reconstruction Loss.
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    summary.scalar('reconstruction_loss', reconstruction_loss)

  # Classification Loss.
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

  return namedtuples.GANLoss(generator_loss, discriminator_loss)


def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
  """Gets generator and discriminator update ops.

  Args:
    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
      `update_ops` is removed, if present.
    gen_scope: A scope for the generator.
    dis_scope: A scope for the discriminator.
    check_for_unused_ops: A Python bool. If `True`, throw an Exception if
      there are unused update ops.

  Returns:
    A 2-tuple of (generator update ops, discriminator update ops).

  Raises:
    ValueError: If there are update ops outside of the generator or
      discriminator scopes.
  """
  if 'update_ops' in kwargs:
    update_ops = set(kwargs['update_ops'])
    del kwargs['update_ops']
  else:
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))

  all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
  all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))

  if check_for_unused_ops:
    unused_ops = update_ops - all_gen_ops - all_dis_ops
    if unused_ops:
      raise ValueError('There are unused update ops: %s' % unused_ops)

  gen_update_ops = list(all_gen_ops & update_ops)
  dis_update_ops = list(all_dis_ops & update_ops)

  return gen_update_ops, dis_update_ops


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TF-GAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    **kwargs: Keyword args to pass directly to `training.create_train_op` for
      both the generator and discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that
    can be used to train a generator/discriminator pair.
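
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and `model`/`loss`
    from `tfgan.gan_model`/`tfgan.gan_loss`:

      train_ops = tfgan.gan_train_ops(
          model, loss,
          generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
          discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))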
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than model and loss from locals.
    # The contents of locals() should not be modified, and modifying it may
    # not affect values, so make a copy.
    # https://docs.python.org/2/library/functions.html#locals.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Get the sync hooks if these are needed.
  sync_hooks = []

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)


# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823).
class RunTrainOpsHook(session_run_hook.SessionRunHook):
  """A hook to run train ops a fixed number of times."""

  def __init__(self, train_ops, train_steps):
    """Run train ops a certain number of times.

    Args:
      train_ops: A train op or iterable of train ops to run.
      train_steps: The number of times to run the op(s).
    """
    if not isinstance(train_ops, (list, tuple)):
      train_ops = [train_ops]
    self._train_ops = train_ops
    self._train_steps = train_steps

  def before_run(self, run_context):
    for _ in range(self._train_steps):
      run_context.session.run(self._train_ops)


def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
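
  Example:
    A minimal sketch: run five discriminator steps for every generator step,
    assuming `tfgan = tf.contrib.gan` and `train_ops`/`logdir` as in
    `gan_train`:

      get_hooks_fn = tfgan.get_sequential_train_hooks(
          train_steps=tfgan.GANTrainSteps(1, 5))
      tfgan.gan_train(train_ops, logdir, get_hooks_fn=get_hooks_fn)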
  """

  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook] + list(train_ops.train_hooks)

  return get_hooks


def _num_joint_steps(train_steps):
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  return num_d_and_g_steps, num_g_steps, num_d_steps


def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for joint GAN training.

  When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
  ALL OPTIMIZERS TO AVOID RACE CONDITIONS.

  The order of steps taken is:
  1) Combined generator and discriminator steps
  2) Generator only steps, if any remain
  3) Discriminator only steps, if any remain

  **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
  for the generator and discriminator simultaneously whenever possible. This
  reduces the number of `tf.Session` calls, and can also change the training
  semantics.

  To illustrate the difference, consider the following example:

  `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
  `get_sequential_train_hooks` to make 8 session calls:
  1) 3 generator steps
  2) 5 discriminator steps

  In contrast, `get_joint_train_steps` will make 5 session calls:
  1) 3 generator + discriminator steps
  2) 2 discriminator steps

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
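
  Example:
    A minimal sketch of the 3/5 split described above, assuming
    `tfgan = tf.contrib.gan` and `train_ops`/`logdir` as in `gan_train`:

      tfgan.gan_train(
          train_ops, logdir,
          get_hooks_fn=tfgan.get_joint_train_hooks(
              tfgan.GANTrainSteps(3, 5)))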
  """
  num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps)

  def get_hooks(train_ops):
    g_op = train_ops.generator_train_op
    d_op = train_ops.discriminator_train_op

    joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
    g_hook = RunTrainOpsHook(g_op, num_g_steps)
    d_hook = RunTrainOpsHook(d_op, num_d_steps)

    return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks)

  return get_hooks


# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=600,
              save_summaries_steps=100,
              config=None):
  """A wrapper around `contrib.training.train` that uses GAN hooks.

  Args:
    train_ops: A GANTrainOps named tuple.
    logdir: The directory where the graph and checkpoints are saved.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    Output of the call to `training.train`.
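
  Example:
    A minimal sketch, assuming `tfgan = tf.contrib.gan` and `train_ops` from
    `tfgan.gan_train_ops`:

      tfgan.gan_train(
          train_ops,
          logdir='/tmp/tfgan_logdir',
          hooks=[tf.train.StopAtStepHook(num_steps=10000)])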
  """
  new_hooks = get_hooks_fn(train_ops)
  if hooks is not None:
    hooks = list(hooks) + list(new_hooks)
  else:
    hooks = new_hooks
  return training.train(
      train_ops.global_step_inc_op,
      logdir,
      master=master,
      is_chief=is_chief,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config)


def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function exists to provide support for the Supervisor. For new code,
  please use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used as `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A TensorFlow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool whether or not the train loop should
      stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = slim_learning.train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = slim_learning.train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps


# Helpers


def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)


def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list. Instead, found '
                     '%s.' % type(distributions_l))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))


def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b


def _generate_stargan_random_domain_target(batch_size, num_domains):
  """Generates random domain labels.

  Args:
    batch_size: (int) Number of random domain labels to generate.
    num_domains: (int) Number of domains represented by the labels.

  Returns:
    Tensor of shape (batch_size, num_domains) representing random one-hot
    domain labels.
  """
  domain_idx = random_ops.random_uniform(
      [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)

  return array_ops.one_hot(domain_idx, num_domains)