# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""The TF-GAN project provides a lightweight GAN training/testing framework.

This file contains the core helper functions to create and train a GAN model.
See the README or examples in `tensorflow_models` for details on how to use.

TF-GAN training occurs in four steps:
1) Create a model
2) Add a loss
3) Create train ops
4) Run the train ops

The functions in this file are organized around these four steps. Each function
corresponds to one of the steps.
"""
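
# A minimal sketch of the four steps, assuming `tfgan` refers to
# `tf.contrib.gan` and that `my_generator_fn`, `my_discriminator_fn`, `noise`,
# and `images` are hypothetical user-supplied functions and tensors:
#
#   # 1) Create a model.
#   model = tfgan.gan_model(my_generator_fn, my_discriminator_fn,
#                           real_data=images, generator_inputs=noise)
#   # 2) Add a loss.
#   loss = tfgan.gan_loss(model)
#   # 3) Create train ops.
#   train_ops = tfgan.gan_train_ops(model, loss,
#                                   tf.train.AdamOptimizer(1e-4),
#                                   tf.train.AdamOptimizer(1e-4))
#   # 4) Run the train ops.
#   tfgan.gan_train(train_ops, logdir='/tmp/tfgan_logdir')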

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import variables as variables_lib
from tensorflow.contrib.gan.python import losses as tfgan_losses
from tensorflow.contrib.gan.python import namedtuples
from tensorflow.contrib.gan.python.losses.python import losses_impl as tfgan_losses_impl
from tensorflow.contrib.slim.python.slim import learning as slim_learning
from tensorflow.contrib.training.python.training import training
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import init_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import random_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops.losses import losses
from tensorflow.python.summary import summary
from tensorflow.python.training import session_run_hook
from tensorflow.python.training import sync_replicas_optimizer
from tensorflow.python.training import training_util

__all__ = [
    'gan_model',
    'infogan_model',
    'acgan_model',
    'cyclegan_model',
    'stargan_model',
    'gan_loss',
    'cyclegan_loss',
    'stargan_loss',
    'gan_train_ops',
    'gan_train',
    'get_sequential_train_hooks',
    'get_joint_train_hooks',
    'get_sequential_train_steps',
    'RunTrainOpsHook',
]


def gan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns GAN model outputs and variables.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    A GANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    discriminator_gen_outputs = discriminator_fn(generated_data,
                                                 generator_inputs)
  with variable_scope.variable_scope(dis_scope, reuse=True):
    real_data = _convert_tensor_or_l_or_d(real_data)
    discriminator_real_outputs = discriminator_fn(real_data, generator_inputs)

  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.GANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn)

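# A hedged sketch of the callables `gan_model` expects; the layer choices and
# names (`my_generator_fn`, `my_discriminator_fn`) are illustrative, not the
# library's prescribed architecture:
#
#   def my_generator_fn(generator_inputs):
#     net = tf.layers.dense(generator_inputs, 64, tf.nn.relu)
#     return tf.layers.dense(net, 28 * 28)  # must match real_data's shape
#
#   def my_discriminator_fn(data, generator_inputs):
#     net = tf.layers.dense(data, 64, tf.nn.relu)
#     return tf.layers.dense(net, 1)  # unscaled logit in [-inf, inf]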

def infogan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    unstructured_generator_inputs,
    structured_generator_inputs,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator'):
  """Returns InfoGAN model outputs and variables.

  See https://arxiv.org/abs/1606.03657 for more details.

  Args:
    generator_fn: A python lambda that takes a list of Tensors as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a 2-tuple of (logits, distribution_list).
      `logits` are in the range [-inf, inf], and `distribution_list` is a list
      of TensorFlow distributions in which the i-th distribution represents the
      predicted noise distribution of the i-th structured noise input.
    real_data: A Tensor representing the real data.
    unstructured_generator_inputs: A list of Tensors to the generator.
      These tensors represent the unstructured noise or conditioning.
    structured_generator_inputs: A list of Tensors to the generator.
      These tensors must have high mutual information with the recognizer.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    An InfoGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    ValueError: If the discriminator output is malformed.
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    unstructured_generator_inputs = _convert_tensor_or_l_or_d(
        unstructured_generator_inputs)
    structured_generator_inputs = _convert_tensor_or_l_or_d(
        structured_generator_inputs)
    generator_inputs = (
        unstructured_generator_inputs + structured_generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as disc_scope:
    dis_gen_outputs, predicted_distributions = discriminator_fn(
        generated_data, generator_inputs)
  _validate_distributions(predicted_distributions, structured_generator_inputs)
  with variable_scope.variable_scope(disc_scope, reuse=True):
    real_data = ops.convert_to_tensor(real_data)
    dis_real_outputs, _ = discriminator_fn(real_data, generator_inputs)

  if not generated_data.get_shape().is_compatible_with(real_data.get_shape()):
    raise ValueError(
        'Generator output shape (%s) must be the same shape as real data '
        '(%s).' % (generated_data.get_shape(), real_data.get_shape()))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(disc_scope)

  return namedtuples.InfoGANModel(
      generator_inputs,
      generated_data,
      generator_variables,
      gen_scope,
      generator_fn,
      real_data,
      dis_real_outputs,
      dis_gen_outputs,
      discriminator_variables,
      disc_scope,
      lambda x, y: discriminator_fn(x, y)[0],  # conform to non-InfoGAN API
      structured_generator_inputs,
      predicted_distributions,
      discriminator_fn)

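# A hedged sketch of an InfoGAN discriminator. It returns real/fake logits plus
# a list of distributions over the structured noise; the single categorical
# code and `num_categories` are hypothetical placeholders (`ds` refers to
# `tf.contrib.distributions`):
#
#   def my_infogan_discriminator_fn(data, generator_inputs):
#     net = tf.layers.dense(data, 64, tf.nn.relu)
#     logits = tf.layers.dense(net, 1)
#     cat_logits = tf.layers.dense(net, num_categories)
#     return logits, [ds.Categorical(logits=cat_logits)]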

def acgan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # Real data and conditioning.
    real_data,
    generator_inputs,
    one_hot_labels,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    # Options.
    check_shapes=True):
  """Returns an ACGANModel with all the pieces needed for ACGAN training.

  The `acgan_model` is the same as the `gan_model` with the only difference
  being that the discriminator additionally outputs logits to classify the
  input (real or generated). Therefore, an explicit field holding
  `one_hot_labels` is necessary, as well as a `discriminator_fn` that outputs
  a 2-tuple holding the logits for real/fake and classification.

  See https://arxiv.org/abs/1610.09585 for more details.

  Args:
    generator_fn: A python lambda that takes `generator_inputs` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a tuple consisting of two Tensors:
        (1) real/fake logits in the range [-inf, inf]
        (2) classification logits in the range [-inf, inf]
    real_data: A Tensor representing the real data.
    generator_inputs: A Tensor or list of Tensors to the generator. In the
      vanilla GAN case, this might be a single noise Tensor. In the conditional
      GAN case, this might be the generator's conditioning.
    one_hot_labels: A Tensor holding one-hot labels for the batch. Needed by
      `acgan_loss`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as the real data. Otherwise, skip this check.

  Returns:
    An ACGANModel namedtuple.

  Raises:
    ValueError: If the generator outputs a Tensor that isn't the same shape as
      `real_data`.
    TypeError: If the discriminator does not output a tuple consisting of
      (discrimination logits, classification logits).
  """
  # Create models.
  with variable_scope.variable_scope(generator_scope) as gen_scope:
    generator_inputs = _convert_tensor_or_l_or_d(generator_inputs)
    generated_data = generator_fn(generator_inputs)
  with variable_scope.variable_scope(discriminator_scope) as dis_scope:
    with ops.name_scope(dis_scope.name + '/generated/'):
      (discriminator_gen_outputs, discriminator_gen_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(generated_data, generator_inputs))
  with variable_scope.variable_scope(dis_scope, reuse=True):
    with ops.name_scope(dis_scope.name + '/real/'):
      real_data = ops.convert_to_tensor(real_data)
      (discriminator_real_outputs, discriminator_real_classification_logits
      ) = _validate_acgan_discriminator_outputs(
          discriminator_fn(real_data, generator_inputs))
  if check_shapes:
    if not generated_data.shape.is_compatible_with(real_data.shape):
      raise ValueError(
          'Generator output shape (%s) must be the same shape as real data '
          '(%s).' % (generated_data.shape, real_data.shape))

  # Get model-specific variables.
  generator_variables = variables_lib.get_trainable_variables(gen_scope)
  discriminator_variables = variables_lib.get_trainable_variables(dis_scope)

  return namedtuples.ACGANModel(
      generator_inputs, generated_data, generator_variables, gen_scope,
      generator_fn, real_data, discriminator_real_outputs,
      discriminator_gen_outputs, discriminator_variables, dis_scope,
      discriminator_fn, one_hot_labels,
      discriminator_real_classification_logits,
      discriminator_gen_classification_logits)

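# A hedged sketch of an ACGAN discriminator returning the expected 2-tuple;
# `num_classes` is a hypothetical placeholder:
#
#   def my_acgan_discriminator_fn(data, generator_inputs):
#     net = tf.layers.dense(data, 64, tf.nn.relu)
#     real_fake_logits = tf.layers.dense(net, 1)
#     class_logits = tf.layers.dense(net, num_classes)
#     return real_fake_logits, class_logits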

def cyclegan_model(
    # Lambdas defining models.
    generator_fn,
    discriminator_fn,
    # data X and Y.
    data_x,
    data_y,
    # Optional scopes.
    generator_scope='Generator',
    discriminator_scope='Discriminator',
    model_x2y_scope='ModelX2Y',
    model_y2x_scope='ModelY2X',
    # Options.
    check_shapes=True):
  """Returns CycleGAN model outputs and variables.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    generator_fn: A python lambda that takes `data_x` or `data_y` as inputs and
      returns the outputs of the GAN generator.
    discriminator_fn: A python lambda that takes `real_data`/`generated_data`
      and `generator_inputs`. Outputs a Tensor in the range [-inf, inf].
    data_x: A `Tensor` of dataset X. Must be the same shape as `data_y`.
    data_y: A `Tensor` of dataset Y. Must be the same shape as `data_x`.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created. Defaults to 'Generator'.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created. Defaults to
      'Discriminator'.
    model_x2y_scope: Optional variable scope for model x2y variables. Defaults
      to 'ModelX2Y'.
    model_y2x_scope: Optional variable scope for model y2x variables. Defaults
      to 'ModelY2X'.
    check_shapes: If `True`, check that the generator produces Tensors that are
      the same shape as `data_x` (`data_y`). Otherwise, skip this check.

  Returns:
    A `CycleGANModel` namedtuple.

  Raises:
    ValueError: If `check_shapes` is `True` and `data_x` or the generator
      output does not have the same shape as `data_y`.
  """

  # Create models.
  def _define_partial_model(input_data, output_data):
    return gan_model(
        generator_fn=generator_fn,
        discriminator_fn=discriminator_fn,
        real_data=output_data,
        generator_inputs=input_data,
        generator_scope=generator_scope,
        discriminator_scope=discriminator_scope,
        check_shapes=check_shapes)

  with variable_scope.variable_scope(model_x2y_scope):
    model_x2y = _define_partial_model(data_x, data_y)
  with variable_scope.variable_scope(model_y2x_scope):
    model_y2x = _define_partial_model(data_y, data_x)

  with variable_scope.variable_scope(model_y2x.generator_scope, reuse=True):
    reconstructed_x = model_y2x.generator_fn(model_x2y.generated_data)
  with variable_scope.variable_scope(model_x2y.generator_scope, reuse=True):
    reconstructed_y = model_x2y.generator_fn(model_y2x.generated_data)

  return namedtuples.CycleGANModel(model_x2y, model_y2x, reconstructed_x,
                                   reconstructed_y)

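# A minimal sketch of a call, assuming hypothetical image batches `images_x`
# and `images_y` of identical shape (the same `generator_fn` is reused for
# both translation directions):
#
#   cyclegan_model = tfgan.cyclegan_model(
#       my_generator_fn, my_discriminator_fn, data_x=images_x, data_y=images_y)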

def stargan_model(generator_fn,
                  discriminator_fn,
                  input_data,
                  input_data_domain_label,
                  generator_scope='Generator',
                  discriminator_scope='Discriminator'):
  """Returns StarGAN model outputs and variables.

  See https://arxiv.org/abs/1711.09020 for more details.

  Args:
    generator_fn: A python lambda that takes `inputs` and `targets` as inputs
      and returns `generated_data` as the transformed version of `inputs`
      based on the `targets`. `inputs` has shape (n, h, w, c), `targets` has
      shape (n, num_domains), and `generated_data` has the same shape as
      `inputs`.
    discriminator_fn: A python lambda that takes `inputs` and `num_domains` as
      inputs and returns a tuple (`source_prediction`, `domain_prediction`).
      `source_prediction` represents the source (real/generated) prediction by
      the discriminator, and `domain_prediction` represents the domain
      prediction/classification by the discriminator. `source_prediction` has
      shape (n) and `domain_prediction` has shape (n, num_domains).
    input_data: Tensor or a list of tensors of shape (n, h, w, c) representing
      the real input images.
    input_data_domain_label: Tensor or a list of tensors of shape (batch_size,
      num_domains) representing the domain label associated with the real
      images.
    generator_scope: Optional generator variable scope. Useful if you want to
      reuse a subgraph that has already been created.
    discriminator_scope: Optional discriminator variable scope. Useful if you
      want to reuse a subgraph that has already been created.

  Returns:
    A StarGANModel namedtuple containing the tensors needed to compute the
    loss.

  Raises:
    ValueError: If the shape of `input_data_domain_label` is not rank 2 or is
      not fully defined in every dimension.
  """

  # Convert to tensor.
  input_data = _convert_tensor_or_l_or_d(input_data)
  input_data_domain_label = _convert_tensor_or_l_or_d(input_data_domain_label)

  # Convert a list of tensors to a single tensor if applicable.
  if isinstance(input_data, (list, tuple)):
    input_data = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data], 0)
  if isinstance(input_data_domain_label, (list, tuple)):
    input_data_domain_label = array_ops.concat(
        [ops.convert_to_tensor(x) for x in input_data_domain_label], 0)

  # Get batch_size, num_domains from the labels.
  input_data_domain_label.shape.assert_has_rank(2)
  input_data_domain_label.shape.assert_is_fully_defined()
  batch_size, num_domains = input_data_domain_label.shape.as_list()

  # Transform input_data to random target domains.
  with variable_scope.variable_scope(generator_scope) as generator_scope:
    generated_data_domain_target = _generate_stargan_random_domain_target(
        batch_size, num_domains)
    generated_data = generator_fn(input_data, generated_data_domain_target)

  # Transform generated_data back to the original input_data domain.
  with variable_scope.variable_scope(generator_scope, reuse=True):
    reconstructed_data = generator_fn(generated_data, input_data_domain_label)

  # Predict source and domain for the generated_data using the discriminator.
  with variable_scope.variable_scope(
      discriminator_scope) as discriminator_scope:
    disc_gen_data_source_pred, disc_gen_data_domain_pred = discriminator_fn(
        generated_data, num_domains)

  # Predict source and domain for the input_data using the discriminator.
  with variable_scope.variable_scope(discriminator_scope, reuse=True):
    disc_input_data_source_pred, disc_input_data_domain_pred = discriminator_fn(
        input_data, num_domains)

  # Collect trainable variables from the neural networks.
  generator_variables = variables_lib.get_trainable_variables(generator_scope)
  discriminator_variables = variables_lib.get_trainable_variables(
      discriminator_scope)

  # Create the StarGANModel namedtuple.
  return namedtuples.StarGANModel(
      input_data=input_data,
      input_data_domain_label=input_data_domain_label,
      generated_data=generated_data,
      generated_data_domain_target=generated_data_domain_target,
      reconstructed_data=reconstructed_data,
      discriminator_input_data_source_predication=disc_input_data_source_pred,
      discriminator_generated_data_source_predication=disc_gen_data_source_pred,
      discriminator_input_data_domain_predication=disc_input_data_domain_pred,
      discriminator_generated_data_domain_predication=disc_gen_data_domain_pred,
      generator_variables=generator_variables,
      generator_scope=generator_scope,
      generator_fn=generator_fn,
      discriminator_variables=discriminator_variables,
      discriminator_scope=discriminator_scope,
      discriminator_fn=discriminator_fn)

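# A minimal sketch of a call, assuming a hypothetical `images` tensor of shape
# (batch, h, w, c) and one-hot `domain_labels` with a fully defined static
# shape (batch, num_domains):
#
#   stargan_model = tfgan.stargan_model(
#       my_stargan_generator_fn, my_stargan_discriminator_fn,
#       input_data=images, input_data_domain_label=domain_labels)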

def _validate_aux_loss_weight(aux_loss_weight, name='aux_loss_weight'):
  if isinstance(aux_loss_weight, ops.Tensor):
    aux_loss_weight.shape.assert_is_compatible_with([])
    with ops.control_dependencies(
        [check_ops.assert_greater_equal(aux_loss_weight, 0.0)]):
      aux_loss_weight = array_ops.identity(aux_loss_weight)
  elif aux_loss_weight is not None and aux_loss_weight < 0:
    raise ValueError('`%s` must be non-negative. Instead, was %s' %
                     (name, aux_loss_weight))
  return aux_loss_weight


def _use_aux_loss(aux_loss_weight):
  if aux_loss_weight is not None:
    if not isinstance(aux_loss_weight, ops.Tensor):
      return aux_loss_weight > 0
    else:
      return True
  else:
    return False


def _tensor_pool_adjusted_model(model, tensor_pool_fn):
  """Adjusts model using `tensor_pool_fn`.

  Args:
    model: A GANModel tuple.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool, and returns a previously stored
      (generated_data, generator_inputs) with some probability. For example,
      `tfgan.features.tensor_pool`.

  Returns:
    A new GANModel tuple where discriminator outputs are adjusted by taking
    pooled generator outputs as inputs. Returns the original model if
    `tensor_pool_fn` is None.

  Raises:
    ValueError: If the tensor pool does not support the `model`.
  """
  if isinstance(model, namedtuples.GANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      dis_gen_outputs = model.discriminator_fn(pooled_generated_data,
                                               pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=dis_gen_outputs)
  elif isinstance(model, namedtuples.ACGANModel):
    pooled_generator_inputs, pooled_generated_data = tensor_pool_fn(
        (model.generator_inputs, model.generated_data))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_discriminator_gen_classification_logits) = model.discriminator_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        discriminator_gen_classification_logits=
        pooled_discriminator_gen_classification_logits)
  elif isinstance(model, namedtuples.InfoGANModel):
    pooled_generator_inputs, pooled_generated_data, pooled_structured_input = (
        tensor_pool_fn((model.generator_inputs, model.generated_data,
                        model.structured_generator_inputs)))
    with variable_scope.variable_scope(model.discriminator_scope, reuse=True):
      (pooled_discriminator_gen_outputs,
       pooled_predicted_distributions) = model.discriminator_and_aux_fn(
           pooled_generated_data, pooled_generator_inputs)
    return model._replace(
        generator_inputs=pooled_generator_inputs,
        generated_data=pooled_generated_data,
        structured_generator_inputs=pooled_structured_input,
        discriminator_gen_outputs=pooled_discriminator_gen_outputs,
        predicted_distributions=pooled_predicted_distributions)
  else:
    raise ValueError('Tensor pool does not support `model`: %s.' % type(model))

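# A minimal sketch of supplying a pool, assuming `functools` and that
# `tfgan.features.tensor_pool` accepts a `pool_size` keyword:
#
#   import functools
#   pool_fn = functools.partial(tfgan.features.tensor_pool, pool_size=50)
#   loss = tfgan.gan_loss(model, tensor_pool_fn=pool_fn)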

def gan_loss(
    # GANModel.
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.wasserstein_generator_loss,
    discriminator_loss_fn=tfgan_losses.wasserstein_discriminator_loss,
    # Auxiliary losses.
    gradient_penalty_weight=None,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    mutual_information_penalty_weight=None,
    aux_cond_generator_weight=None,
    aux_cond_discriminator_weight=None,
    tensor_pool_fn=None,
    # Options.
    add_summaries=True):
  """Returns losses necessary to train generator and discriminator.

  Args:
    model: A GANModel tuple.
    generator_loss_fn: The loss function on the generator. Takes a GANModel
      tuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      GANModel tuple.
    gradient_penalty_weight: If not `None`, must be a non-negative Python
      number or Tensor indicating how much to weight the gradient penalty. See
      https://arxiv.org/pdf/1704.00028.pdf for more details.
    gradient_penalty_epsilon: If `gradient_penalty_weight` is not None, the
      small positive value used by the gradient penalty function for numerical
      stability. Note some applications will need to increase this value to
      avoid NaNs.
    gradient_penalty_target: If `gradient_penalty_weight` is not None, a Python
      number or `Tensor` indicating the target value of gradient norm. See the
      CIFAR10 section of https://arxiv.org/abs/1710.10196. Defaults to 1.0.
    gradient_penalty_one_sided: If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    mutual_information_penalty_weight: If not `None`, must be a non-negative
      Python number or Tensor indicating how much to weight the mutual
      information penalty. See https://arxiv.org/abs/1606.03657 for more
      details.
    aux_cond_generator_weight: If not `None`, add a classification loss as in
      https://arxiv.org/abs/1610.09585.
    aux_cond_discriminator_weight: If not `None`, add a classification loss as
      in https://arxiv.org/abs/1610.09585.
    tensor_pool_fn: A function that takes (generated_data, generator_inputs),
      stores them in an internal pool, and returns a previously stored
      (generated_data, generator_inputs). For example,
      `tfgan.features.tensor_pool`. Defaults to None (not using tensor pool).
    add_summaries: Whether or not to add summaries for the losses.

  Returns:
    A GANLoss 2-tuple of (generator_loss, discriminator_loss). Includes
    regularization losses.

  Raises:
    ValueError: If any of the auxiliary loss weights is provided and negative.
    ValueError: If `mutual_information_penalty_weight` is provided, but the
      `model` isn't an `InfoGANModel`.
  """
  # Validate arguments.
  gradient_penalty_weight = _validate_aux_loss_weight(
      gradient_penalty_weight, 'gradient_penalty_weight')
  mutual_information_penalty_weight = _validate_aux_loss_weight(
      mutual_information_penalty_weight, 'infogan_weight')
  aux_cond_generator_weight = _validate_aux_loss_weight(
      aux_cond_generator_weight, 'aux_cond_generator_weight')
  aux_cond_discriminator_weight = _validate_aux_loss_weight(
      aux_cond_discriminator_weight, 'aux_cond_discriminator_weight')

  # Verify configuration for mutual information penalty (InfoGAN).
  if (_use_aux_loss(mutual_information_penalty_weight) and
      not isinstance(model, namedtuples.InfoGANModel)):
    raise ValueError(
        'When `mutual_information_penalty_weight` is provided, `model` must be '
        'an `InfoGANModel`. Instead, was %s.' % type(model))

  # Verify configuration for auxiliary condition loss (ACGAN).
  if ((_use_aux_loss(aux_cond_generator_weight) or
       _use_aux_loss(aux_cond_discriminator_weight)) and
      not isinstance(model, namedtuples.ACGANModel)):
    raise ValueError(
        'When `aux_cond_generator_weight` or `aux_cond_discriminator_weight` '
        'is provided, `model` must be an `ACGANModel`. Instead, was %s.' %
        type(model))

  # Optionally create pooled model.
  if tensor_pool_fn:
    pooled_model = _tensor_pool_adjusted_model(model, tensor_pool_fn)
  else:
    pooled_model = model

  # Create standard losses.
  gen_loss = generator_loss_fn(model, add_summaries=add_summaries)
  dis_loss = discriminator_loss_fn(pooled_model, add_summaries=add_summaries)

  # Add optional extra losses.
  if _use_aux_loss(gradient_penalty_weight):
    gp_loss = tfgan_losses.wasserstein_gradient_penalty(
        pooled_model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries)
    dis_loss += gradient_penalty_weight * gp_loss
  if _use_aux_loss(mutual_information_penalty_weight):
    gen_info_loss = tfgan_losses.mutual_information_penalty(
        model, add_summaries=add_summaries)
    if tensor_pool_fn is None:
      dis_info_loss = gen_info_loss
    else:
      dis_info_loss = tfgan_losses.mutual_information_penalty(
          pooled_model, add_summaries=add_summaries)
    gen_loss += mutual_information_penalty_weight * gen_info_loss
    dis_loss += mutual_information_penalty_weight * dis_info_loss
  if _use_aux_loss(aux_cond_generator_weight):
    ac_gen_loss = tfgan_losses.acgan_generator_loss(
        model, add_summaries=add_summaries)
    gen_loss += aux_cond_generator_weight * ac_gen_loss
  if _use_aux_loss(aux_cond_discriminator_weight):
    ac_disc_loss = tfgan_losses.acgan_discriminator_loss(
        pooled_model, add_summaries=add_summaries)
    dis_loss += aux_cond_discriminator_weight * ac_disc_loss

  # Gather regularization losses from the generator and discriminator scopes.
  if model.generator_scope:
    gen_reg_loss = losses.get_regularization_loss(model.generator_scope.name)
  else:
    gen_reg_loss = 0
  if model.discriminator_scope:
    dis_reg_loss = losses.get_regularization_loss(
        model.discriminator_scope.name)
  else:
    dis_reg_loss = 0

  return namedtuples.GANLoss(gen_loss + gen_reg_loss, dis_loss + dis_reg_loss)

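# A minimal sketch adding the WGAN-GP penalty from the docstring above (weight
# 10 follows https://arxiv.org/abs/1704.00028; `model` is assumed to be a
# GANModel built earlier):
#
#   loss = tfgan.gan_loss(model, gradient_penalty_weight=10.0)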

def cyclegan_loss(
    model,
    # Loss functions.
    generator_loss_fn=tfgan_losses.least_squares_generator_loss,
    discriminator_loss_fn=tfgan_losses.least_squares_discriminator_loss,
    # Auxiliary losses.
    cycle_consistency_loss_fn=tfgan_losses.cycle_consistency_loss,
    cycle_consistency_loss_weight=10.0,
    # Options.
    **kwargs):
  """Returns the losses for a `CycleGANModel`.

  See https://arxiv.org/abs/1703.10593 for more details.

  Args:
    model: A `CycleGANModel` namedtuple.
    generator_loss_fn: The loss function on the generator. Takes a `GANModel`
      namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `GANModel` namedtuple.
    cycle_consistency_loss_fn: The cycle consistency loss function. Takes a
      `CycleGANModel` namedtuple.
    cycle_consistency_loss_weight: A non-negative Python number or a scalar
      `Tensor` indicating how much to weigh the cycle consistency loss.
    **kwargs: Keyword args to pass directly to `gan_loss` to construct the loss
      for each partial model of `model`.

  Returns:
    A `CycleGANLoss` namedtuple.

  Raises:
    ValueError: If `model` is not a `CycleGANModel` namedtuple.
  """
  # Sanity checks.
  if not isinstance(model, namedtuples.CycleGANModel):
    raise ValueError(
        '`model` must be a `CycleGANModel`. Instead, was %s.' % type(model))

  # Defines cycle consistency loss.
  cycle_consistency_loss = cycle_consistency_loss_fn(
      model, add_summaries=kwargs.get('add_summaries', True))
  cycle_consistency_loss_weight = _validate_aux_loss_weight(
      cycle_consistency_loss_weight, 'cycle_consistency_loss_weight')
  aux_loss = cycle_consistency_loss_weight * cycle_consistency_loss

  # Defines losses for each partial model.
  def _partial_loss(partial_model):
    partial_loss = gan_loss(
        partial_model,
        generator_loss_fn=generator_loss_fn,
        discriminator_loss_fn=discriminator_loss_fn,
        **kwargs)
    return partial_loss._replace(generator_loss=partial_loss.generator_loss +
                                 aux_loss)

  with ops.name_scope('cyclegan_loss_x2y'):
    loss_x2y = _partial_loss(model.model_x2y)
  with ops.name_scope('cyclegan_loss_y2x'):
    loss_y2x = _partial_loss(model.model_y2x)

  return namedtuples.CycleGANLoss(loss_x2y, loss_y2x)

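# A minimal sketch, assuming `cyclegan_model` is a CycleGANModel built above
# (weight 10 follows the paper's default for cycle consistency):
#
#   cyclegan_loss = tfgan.cyclegan_loss(
#       cyclegan_model, cycle_consistency_loss_weight=10.0)
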
def stargan_loss(
    model,
    generator_loss_fn=tfgan_losses.stargan_generator_loss_wrapper(
        tfgan_losses_impl.wasserstein_generator_loss),
    discriminator_loss_fn=tfgan_losses.stargan_discriminator_loss_wrapper(
        tfgan_losses_impl.wasserstein_discriminator_loss),
    gradient_penalty_weight=10.0,
    gradient_penalty_epsilon=1e-10,
    gradient_penalty_target=1.0,
    gradient_penalty_one_sided=False,
    reconstruction_loss_fn=losses.absolute_difference,
    reconstruction_loss_weight=10.0,
    classification_loss_fn=losses.softmax_cross_entropy,
    classification_loss_weight=1.0,
    classification_one_hot=True,
    add_summaries=True):
  """StarGAN Loss.

  Args:
    model: (StarGAN) Model output of the stargan_model() function call.
    generator_loss_fn: The loss function on the generator. Takes a
      `StarGANModel` namedtuple.
    discriminator_loss_fn: The loss function on the discriminator. Takes a
      `StarGANModel` namedtuple.
    gradient_penalty_weight: (float) Gradient penalty weight. Defaults to 10
      per the original paper, https://arxiv.org/abs/1711.09020. Set to 0 or
      None to turn off gradient penalty.
    gradient_penalty_epsilon: (float) A small positive number added for
      numerical stability when computing the gradient norm.
    gradient_penalty_target: (float, or tf.float `Tensor`) The target value of
      gradient norm. Defaults to 1.0.
    gradient_penalty_one_sided: (bool) If `True`, penalty proposed in
      https://arxiv.org/abs/1709.08894 is used. Defaults to `False`.
    reconstruction_loss_fn: The reconstruction loss function. Defaults to
      L1-norm, and the function must conform to the `tf.losses` API.
    reconstruction_loss_weight: Reconstruction loss weight. Defaults to 10.0.
    classification_loss_fn: The loss function on the discriminator's ability to
      classify the domain of the input. Defaults to one-hot softmax cross
      entropy loss, and the function must conform to the `tf.losses` API.
    classification_loss_weight: (float) Classification loss weight. Defaults to
      1.0.
    classification_one_hot: (bool) Whether the labels are in one-hot
      representation. Defaults to `True`. If `False`, `classification_loss_fn`
      needs to be a sigmoid cross entropy loss instead.
    add_summaries: (bool) Whether to add the loss to the summary.

  Returns:
    GANLoss namedtuple where we have generator loss and discriminator loss.

  Raises:
    ValueError: If `StarGANModel.input_data_domain_label` does not have rank 2,
      or if its second dimension is not defined.
  """

  def _classification_loss_helper(true_labels, predict_logits, scope_name):
    """Classification Loss Function Helper.

    Args:
      true_labels: Tensor of shape [batch_size, num_domains] representing the
        label where each row is a one-hot vector.
      predict_logits: Tensor of shape [batch_size, num_domains] representing
        the predicted label logits, which are unscaled output from the NN.
      scope_name: (string) Name scope of the loss component.

    Returns:
      Single scalar tensor representing the classification loss.
    """

    with ops.name_scope(scope_name, values=(true_labels, predict_logits)):

      loss = classification_loss_fn(
          onehot_labels=true_labels, logits=predict_logits)

      if not classification_one_hot:
        loss = math_ops.reduce_sum(loss, axis=1)
      loss = math_ops.reduce_mean(loss)

      if add_summaries:
        summary.scalar(scope_name, loss)

      return loss

  # Check input shape.
  model.input_data_domain_label.shape.assert_has_rank(2)
  model.input_data_domain_label.shape[1:].assert_is_fully_defined()

  # Adversarial Loss.
  generator_loss = generator_loss_fn(model, add_summaries=add_summaries)
  discriminator_loss = discriminator_loss_fn(model, add_summaries=add_summaries)

  # Gradient Penalty.
  if _use_aux_loss(gradient_penalty_weight):
    gradient_penalty_fn = tfgan_losses.stargan_gradient_penalty_wrapper(
        tfgan_losses_impl.wasserstein_gradient_penalty)
    discriminator_loss += gradient_penalty_fn(
        model,
        epsilon=gradient_penalty_epsilon,
        target=gradient_penalty_target,
        one_sided=gradient_penalty_one_sided,
        add_summaries=add_summaries) * gradient_penalty_weight

  # Reconstruction Loss.
  reconstruction_loss = reconstruction_loss_fn(model.input_data,
                                               model.reconstructed_data)
  generator_loss += reconstruction_loss * reconstruction_loss_weight
  if add_summaries:
    summary.scalar('reconstruction_loss', reconstruction_loss)

  # Classification Loss.
  generator_loss += _classification_loss_helper(
      true_labels=model.generated_data_domain_target,
      predict_logits=model.discriminator_generated_data_domain_predication,
      scope_name='generator_classification_loss') * classification_loss_weight
  discriminator_loss += _classification_loss_helper(
      true_labels=model.input_data_domain_label,
      predict_logits=model.discriminator_input_data_domain_predication,
      scope_name='discriminator_classification_loss'
  ) * classification_loss_weight

  return namedtuples.GANLoss(generator_loss, discriminator_loss)

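# A minimal sketch with the paper's default weights, assuming `stargan_model`
# is a StarGANModel built above:
#
#   stargan_loss = tfgan.stargan_loss(
#       stargan_model, reconstruction_loss_weight=10.0,
#       classification_loss_weight=1.0)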

def _get_update_ops(kwargs, gen_scope, dis_scope, check_for_unused_ops=True):
  """Gets generator and discriminator update ops.

  Args:
    kwargs: A dictionary of kwargs to be passed to `create_train_op`.
      `update_ops` is removed, if present.
    gen_scope: A scope for the generator.
    dis_scope: A scope for the discriminator.
    check_for_unused_ops: A Python bool. If `True`, throw Exception if there
      are unused update ops.

  Returns:
    A 2-tuple of (generator update ops, discriminator update ops).

  Raises:
    ValueError: If there are update ops outside of the generator or
      discriminator scopes.
  """
  if 'update_ops' in kwargs:
    update_ops = set(kwargs['update_ops'])
    del kwargs['update_ops']
  else:
    update_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS))

  all_gen_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, gen_scope))
  all_dis_ops = set(ops.get_collection(ops.GraphKeys.UPDATE_OPS, dis_scope))

  if check_for_unused_ops:
    unused_ops = update_ops - all_gen_ops - all_dis_ops
    if unused_ops:
      raise ValueError('There are unused update ops: %s' % unused_ops)

  gen_update_ops = list(all_gen_ops & update_ops)
  dis_update_ops = list(all_dis_ops & update_ops)

  return gen_update_ops, dis_update_ops


def gan_train_ops(
    model,
    loss,
    generator_optimizer,
    discriminator_optimizer,
    check_for_unused_update_ops=True,
    is_chief=True,
    # Optional args to pass directly to the `create_train_op`.
    **kwargs):
  """Returns GAN train ops.

  The highest-level call in TF-GAN. It is composed of functions that can also
  be called, should a user require more control over some part of the GAN
  training process.

  Args:
    model: A GANModel.
    loss: A GANLoss.
    generator_optimizer: The optimizer for generator updates.
    discriminator_optimizer: The optimizer for the discriminator updates.
    check_for_unused_update_ops: If `True`, throws an exception if there are
      update ops outside of the generator or discriminator scopes.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    **kwargs: Keyword args to pass directly to
      `training.create_train_op` for both the generator and
      discriminator train op.

  Returns:
    A GANTrainOps tuple of (generator_train_op, discriminator_train_op) that
    can be used to train a generator/discriminator pair.
  """
  if isinstance(model, namedtuples.CycleGANModel):
    # Get and store all arguments other than model and loss from locals.
    # The contents of locals() should not be modified, and modifying them may
    # not affect the underlying values anyway, so make a copy.
    # https://docs.python.org/2/library/functions.html#locals.
    saved_params = dict(locals())
    saved_params.pop('model', None)
    saved_params.pop('loss', None)
    kwargs = saved_params.pop('kwargs', {})
    saved_params.update(kwargs)
    with ops.name_scope('cyclegan_x2y_train'):
      train_ops_x2y = gan_train_ops(model.model_x2y, loss.loss_x2y,
                                    **saved_params)
    with ops.name_scope('cyclegan_y2x_train'):
      train_ops_y2x = gan_train_ops(model.model_y2x, loss.loss_y2x,
                                    **saved_params)
    return namedtuples.GANTrainOps(
        (train_ops_x2y.generator_train_op, train_ops_y2x.generator_train_op),
        (train_ops_x2y.discriminator_train_op,
         train_ops_y2x.discriminator_train_op),
        training_util.get_or_create_global_step().assign_add(1))

  # Create global step increment op.
  global_step = training_util.get_or_create_global_step()
  global_step_inc = global_step.assign_add(1)

  # Get generator and discriminator update ops. We split them so that update
  # ops aren't accidentally run multiple times. For now, throw an error if
  # there are update ops that aren't associated with either the generator or
  # the discriminator. Might modify the `kwargs` dictionary.
  gen_update_ops, dis_update_ops = _get_update_ops(
      kwargs, model.generator_scope.name, model.discriminator_scope.name,
      check_for_unused_update_ops)

  # Get the sync hooks if these are needed.
  sync_hooks = []

  generator_global_step = None
  if isinstance(generator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # TODO(joelshor): Figure out a way to get this to work without including
    # the dummy global step in the checkpoint.
    # WARNING: Making this variable a local variable causes sync replicas to
    # hang forever.
    generator_global_step = variable_scope.get_variable(
        'dummy_global_step_generator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    gen_update_ops += [generator_global_step.assign(global_step)]
    sync_hooks.append(generator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('generator_train'):
    gen_train_op = training.create_train_op(
        total_loss=loss.generator_loss,
        optimizer=generator_optimizer,
        variables_to_train=model.generator_variables,
        global_step=generator_global_step,
        update_ops=gen_update_ops,
        **kwargs)

  discriminator_global_step = None
  if isinstance(discriminator_optimizer,
                sync_replicas_optimizer.SyncReplicasOptimizer):
    # See comment above `generator_global_step`.
    discriminator_global_step = variable_scope.get_variable(
        'dummy_global_step_discriminator',
        shape=[],
        dtype=global_step.dtype.base_dtype,
        initializer=init_ops.zeros_initializer(),
        trainable=False,
        collections=[ops.GraphKeys.GLOBAL_VARIABLES])
    dis_update_ops += [discriminator_global_step.assign(global_step)]
    sync_hooks.append(discriminator_optimizer.make_session_run_hook(is_chief))
  with ops.name_scope('discriminator_train'):
    disc_train_op = training.create_train_op(
        total_loss=loss.discriminator_loss,
        optimizer=discriminator_optimizer,
        variables_to_train=model.discriminator_variables,
        global_step=discriminator_global_step,
        update_ops=dis_update_ops,
        **kwargs)

  return namedtuples.GANTrainOps(gen_train_op, disc_train_op, global_step_inc,
                                 sync_hooks)

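# A minimal sketch; `summarize_gradients` is forwarded through `**kwargs` to
# `training.create_train_op` (assuming that keyword is supported by the
# installed version):
#
#   train_ops = tfgan.gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
#       discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
#       summarize_gradients=True)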

# TODO(joelshor): Implement a dynamic GAN train loop, as in `Real-Time Adaptive
# Image Compression` (https://arxiv.org/abs/1705.05823).
class RunTrainOpsHook(session_run_hook.SessionRunHook):
  """A hook to run train ops a fixed number of times."""

  def __init__(self, train_ops, train_steps):
    """Run train ops a certain number of times.

    Args:
      train_ops: A train op or iterable of train ops to run.
      train_steps: The number of times to run the op(s).
    """
    if not isinstance(train_ops, (list, tuple)):
      train_ops = [train_ops]
    self._train_ops = train_ops
    self._train_steps = train_steps

  def before_run(self, run_context):
    for _ in range(self._train_steps):
      run_context.session.run(self._train_ops)

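# For example, RunTrainOpsHook(train_ops.discriminator_train_op, 5) runs five
# discriminator updates inside each MonitoredSession.run() call.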

def get_sequential_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a hooks function for sequential GAN training.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that takes a GANTrainOps tuple and returns a list of hooks.
  """

  def get_hooks(train_ops):
    generator_hook = RunTrainOpsHook(train_ops.generator_train_op,
                                     train_steps.generator_train_steps)
    discriminator_hook = RunTrainOpsHook(train_ops.discriminator_train_op,
                                         train_steps.discriminator_train_steps)
    return [generator_hook, discriminator_hook] + list(train_ops.train_hooks)

  return get_hooks


def _num_joint_steps(train_steps):
  g_steps = train_steps.generator_train_steps
  d_steps = train_steps.discriminator_train_steps
  # Get the number of each type of step that should be run.
  num_d_and_g_steps = min(g_steps, d_steps)
  num_g_steps = g_steps - num_d_and_g_steps
  num_d_steps = d_steps - num_d_and_g_steps

  return num_d_and_g_steps, num_g_steps, num_d_steps

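# For example, _num_joint_steps(GANTrainSteps(3, 5)) returns (3, 0, 2):
# 3 joint steps, 0 generator-only steps, and 2 discriminator-only steps.
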
   1096 def get_joint_train_hooks(train_steps=namedtuples.GANTrainSteps(1, 1)):
   1097   """Returns a hooks function for joint GAN training.
   1098 
   1099   When using these train hooks, IT IS RECOMMENDED TO USE `use_locking=True` ON
   1100   ALL OPTIMIZERS TO AVOID RACE CONDITIONS.
   1101 
   1102   The order of steps taken is:
   1103   1) Combined generator and discriminator steps
   1104   2) Generator only steps, if any remain
   1105   3) Discriminator only steps, if any remain
   1106 
   1107   **NOTE**: Unlike `get_sequential_train_hooks`, this method performs updates
   1108   for the generator and discriminator simultaneously whenever possible. This
   1109   reduces the number of `tf.Session` calls, and can also change the training
   1110   semantics.
   1111 
   1112   To illustrate the difference look at the following example:
   1113 
   1114   `train_steps=namedtuples.GANTrainSteps(3, 5)` will cause
   1115   `get_sequential_train_hooks` to make 8 session calls:
   1116     1) 3 generator steps
   1117     2) 5 discriminator steps
   1118 
   1119   In contrast, `get_joint_train_steps` will make 5 session calls:
   1120   1) 3 generator + discriminator steps
   1121   2) 2 discriminator steps
   1122 
   1123   Args:
   1124     train_steps: A `GANTrainSteps` tuple that determines how many generator
   1125       and discriminator training steps to take.
   1126 
   1127   Returns:
   1128     A function that takes a GANTrainOps tuple and returns a list of hooks.
   1129   """
   1130   num_d_and_g_steps, num_g_steps, num_d_steps = _num_joint_steps(train_steps)
   1131 
   1132   def get_hooks(train_ops):
   1133     g_op = train_ops.generator_train_op
   1134     d_op = train_ops.discriminator_train_op
   1135 
   1136     joint_hook = RunTrainOpsHook([g_op, d_op], num_d_and_g_steps)
   1137     g_hook = RunTrainOpsHook(g_op, num_g_steps)
   1138     d_hook = RunTrainOpsHook(d_op, num_d_steps)
   1139 
   1140     return [joint_hook, g_hook, d_hook] + list(train_ops.train_hooks)
   1141 
   1142   return get_hooks
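
# A minimal usage sketch (assumes `train_ops` came from `gan_train_ops`, with
# optimizers constructed with `use_locking=True` per the note above):
#
#   hooks_fn = get_joint_train_hooks(namedtuples.GANTrainSteps(3, 5))
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir', get_hooks_fn=hooks_fn)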


# TODO(joelshor): This function currently returns the global step. Find a
# good way for it to return the generator, discriminator, and final losses.
def gan_train(train_ops,
              logdir,
              get_hooks_fn=get_sequential_train_hooks(),
              master='',
              is_chief=True,
              scaffold=None,
              hooks=None,
              chief_only_hooks=None,
              save_checkpoint_secs=600,
              save_summaries_steps=100,
              config=None):
  """A wrapper around `contrib.training.train` that uses GAN hooks.

  Args:
    train_ops: A GANTrainOps named tuple.
    logdir: The directory where the graph and checkpoints are saved.
    get_hooks_fn: A function that takes a GANTrainOps tuple and returns a list
      of hooks.
    master: The URL of the master.
    is_chief: Specifies whether or not the training is being run by the primary
      replica during replica training.
    scaffold: A `tf.train.Scaffold` instance.
    hooks: List of `tf.train.SessionRunHook` callbacks which are run inside the
      training loop.
    chief_only_hooks: List of `tf.train.SessionRunHook` instances which are run
      inside the training loop for the chief trainer only.
    save_checkpoint_secs: The frequency, in seconds, that a checkpoint is saved
      using a default checkpoint saver. If `save_checkpoint_secs` is set to
      `None`, then the default checkpoint saver isn't used.
    save_summaries_steps: The frequency, in number of global steps, that the
      summaries are written to disk using a default summary saver. If
      `save_summaries_steps` is set to `None`, then the default summary saver
      isn't used.
    config: An instance of `tf.ConfigProto`.

  Returns:
    Output of the call to `training.train`.
  """
  new_hooks = get_hooks_fn(train_ops)
  if hooks is not None:
    hooks = list(hooks) + list(new_hooks)
  else:
    hooks = new_hooks
  return training.train(
      train_ops.global_step_inc_op,
      logdir,
      master=master,
      is_chief=is_chief,
      scaffold=scaffold,
      hooks=hooks,
      chief_only_hooks=chief_only_hooks,
      save_checkpoint_secs=save_checkpoint_secs,
      save_summaries_steps=save_summaries_steps,
      config=config)
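
# An end-to-end sketch of the four TF-GAN steps, with `gan_train` as step 4.
# `my_generator_fn`, `my_discriminator_fn`, `images`, and `noise` are
# hypothetical placeholders supplied by the caller, and the optimizer
# settings are illustrative only:
#
#   model = gan_model(my_generator_fn, my_discriminator_fn, images, noise)
#   loss = gan_loss(model)
#   train_ops = gan_train_ops(
#       model, loss,
#       generator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5),
#       discriminator_optimizer=tf.train.AdamOptimizer(1e-4, beta1=0.5))
#   gan_train(train_ops, logdir='/tmp/tfgan_logdir',
#             hooks=[tf.train.StopAtStepHook(num_steps=10000)])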


def get_sequential_train_steps(train_steps=namedtuples.GANTrainSteps(1, 1)):
  """Returns a thin wrapper around slim.learning.train_step, for GANs.

  This function exists to provide support for the Supervisor. For new code,
  please use `MonitoredSession` and `get_sequential_train_hooks`.

  Args:
    train_steps: A `GANTrainSteps` tuple that determines how many generator
      and discriminator training steps to take.

  Returns:
    A function that can be used for `train_step_fn` for GANs.
  """

  def sequential_train_steps(sess, train_ops, global_step, train_step_kwargs):
    """A thin wrapper around slim.learning.train_step, for GANs.

    Args:
      sess: A TensorFlow session.
      train_ops: A GANTrainOps tuple of train ops to run.
      global_step: The global step.
      train_step_kwargs: Dictionary controlling `train_step` behavior.

    Returns:
      A scalar final loss and a bool indicating whether the train loop should
      stop.
    """
    # Only run `should_stop` at the end, if required. Make a local copy of
    # `train_step_kwargs`, if necessary, so as not to modify the caller's
    # dictionary.
    should_stop_op, train_kwargs = None, train_step_kwargs
    if 'should_stop' in train_step_kwargs:
      should_stop_op = train_step_kwargs['should_stop']
      train_kwargs = train_step_kwargs.copy()
      del train_kwargs['should_stop']

    # Run generator training steps.
    gen_loss = 0
    for _ in range(train_steps.generator_train_steps):
      cur_gen_loss, _ = slim_learning.train_step(
          sess, train_ops.generator_train_op, global_step, train_kwargs)
      gen_loss += cur_gen_loss

    # Run discriminator training steps.
    dis_loss = 0
    for _ in range(train_steps.discriminator_train_steps):
      cur_dis_loss, _ = slim_learning.train_step(
          sess, train_ops.discriminator_train_op, global_step, train_kwargs)
      dis_loss += cur_dis_loss

    sess.run(train_ops.global_step_inc_op)

    # Run the `should_stop` op after the global step has been incremented, so
    # that the `should_stop` aligns with the proper `global_step` count.
    if should_stop_op is not None:
      should_stop = sess.run(should_stop_op)
    else:
      should_stop = False

    return gen_loss + dis_loss, should_stop

  return sequential_train_steps
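
# A minimal usage sketch for the legacy slim/Supervisor path (prefer
# `gan_train` for new code). The whole `GANTrainOps` tuple is passed where
# `slim_learning.train` expects a train op, since a custom `train_step_fn`
# receives that argument unchanged:
#
#   train_step_fn = get_sequential_train_steps(namedtuples.GANTrainSteps(1, 1))
#   slim_learning.train(train_ops, logdir='/tmp/tfgan_logdir',
#                       train_step_fn=train_step_fn)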


# Helpers


def _convert_tensor_or_l_or_d(tensor_or_l_or_d):
  """Convert input, list of inputs, or dictionary of inputs to Tensors."""
  if isinstance(tensor_or_l_or_d, (list, tuple)):
    return [ops.convert_to_tensor(x) for x in tensor_or_l_or_d]
  elif isinstance(tensor_or_l_or_d, dict):
    return {k: ops.convert_to_tensor(v) for k, v in tensor_or_l_or_d.items()}
  else:
    return ops.convert_to_tensor(tensor_or_l_or_d)
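
# For example, a dictionary input is converted value-by-value:
#
#   _convert_tensor_or_l_or_d({'labels': [0, 1], 'weights': 0.5})
#   # => {'labels': <int32 Tensor>, 'weights': <float32 Tensor>}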


def _validate_distributions(distributions_l, noise_l):
  if not isinstance(distributions_l, (tuple, list)):
    raise ValueError('`predicted_distributions` must be a list or tuple. '
                     'Instead, found %s.' % type(distributions_l))
  if len(distributions_l) != len(noise_l):
    raise ValueError('Length of `predicted_distributions` %i must be the same '
                     'as the length of structured noise %i.' %
                     (len(distributions_l), len(noise_l)))


def _validate_acgan_discriminator_outputs(discriminator_output):
  try:
    a, b = discriminator_output
  except (TypeError, ValueError):
    raise TypeError(
        'A discriminator function for ACGAN must output a tuple '
        'consisting of (discrimination logits, classification logits).')
  return a, b


def _generate_stargan_random_domain_target(batch_size, num_domains):
  """Generates random one-hot domain labels.

  Args:
    batch_size: (int) Number of random domain labels to generate.
    num_domains: (int) Number of domains represented by the labels.

  Returns:
    Tensor of shape (batch_size, num_domains) representing random one-hot
    domain labels.
  """
  domain_idx = random_ops.random_uniform(
      [batch_size], minval=0, maxval=num_domains, dtype=dtypes.int32)

  return array_ops.one_hot(domain_idx, num_domains)
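
# For example, `_generate_stargan_random_domain_target(4, 3)` might evaluate
# to one-hot rows such as:
#
#   [[0., 1., 0.],
#    [1., 0., 0.],
#    [0., 0., 1.],
#    [0., 1., 0.]]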