     15 """Model evaluation tools for TFGAN.
     17 These methods come from https://arxiv.org/abs/1606.03498 and
     18 https://arxiv.org/abs/1706.08500.
     20 NOTE: This implementation uses the same weights as in
     21 https://github.com/openai/improved-gan/blob/master/inception_score/model.py,
     22 but is more numerically stable and is an unbiased estimator of the true
     23 Inception score even when splitting the inputs into batches.
     24 """
     26 from __future__ import absolute_import
     27 from __future__ import division
     28 from __future__ import print_function
     30 import functools
     31 import os
     32 import sys
     33 import tarfile
     35 from six.moves import urllib
     37 from tensorflow.contrib.layers.python.layers import layers
     38 from tensorflow.core.framework import graph_pb2
     39 from tensorflow.python.framework import dtypes
     40 from tensorflow.python.framework import importer
     41 from tensorflow.python.framework import ops
     42 from tensorflow.python.ops import array_ops
     43 from tensorflow.python.ops import functional_ops
     44 from tensorflow.python.ops import image_ops
     45 from tensorflow.python.ops import linalg_ops
     46 from tensorflow.python.ops import math_ops
     47 from tensorflow.python.ops import nn_ops
     48 from tensorflow.python.platform import gfile
     49 from tensorflow.python.platform import resource_loader
# Public API of this module: names exported by `from ... import *`.
__all__ = [
    # Loading pretrained classifier graphs.
    'get_graph_def_from_disk',
    'get_graph_def_from_resource',
    'get_graph_def_from_url_tarball',
    # Running images through a classifier.
    'preprocess_image',
    'run_image_classifier',
    'run_inception',
    # Inception / classifier score.
    'inception_score',
    'classifier_score',
    'classifier_score_from_logits',
    # Frechet Inception / classifier distance.
    'frechet_inception_distance',
    'frechet_classifier_distance',
    'frechet_classifier_distance_from_activations',
]
     69 INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
     70 INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
     71 INCEPTION_INPUT = 'Mul:0'
     72 INCEPTION_OUTPUT = 'logits:0'
     73 INCEPTION_FINAL_POOL = 'pool_3:0'
     77 def _validate_images(images, image_size):
     78   images = ops.convert_to_tensor(images)
     79   images.shape.with_rank(4)
     80   images.shape.assert_is_compatible_with(
     81       [None, image_size, image_size, None])
     82   return images
     85 def _symmetric_matrix_square_root(mat, eps=1e-10):
     86   """Compute square root of a symmetric matrix.
     88   Note that this is different from an elementwise square root. We want to
     89   compute M' where M' = sqrt(mat) such that M' * M' = mat.
     91   Also note that this method **only** works for symmetric matrices.
     93   Args:
     94     mat: Matrix to take the square root of.
     95     eps: Small epsilon such that any element less than eps will not be square
     96       rooted to guard against numerical instability.
     98   Returns:
     99     Matrix square root of mat.
    100   """
    101   # Unlike numpy, tensorflow's return order is (s, u, v)
    102   s, u, v = linalg_ops.svd(mat)
    103   # sqrt is unstable around 0, just use 0 in such case
    104   si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
    105   # Note that the v returned by Tensorflow is v = V
    106   # (when referencing the equation A = U S V^T)
    107   # This is unlike Numpy which returns v = V^T
    108   return math_ops.matmul(
    109       math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
    112 def preprocess_image(
    113     images, height=INCEPTION_DEFAULT_IMAGE_SIZE,
    114     width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None):
    115   """Prepare a batch of images for evaluation.
    117   This is the preprocessing portion of the graph from
    118   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.
    120   Note that it expects Tensors in [0, 255]. This function maps pixel values to
    121   [-1, 1] and resizes to match the InceptionV1 network.
    123   Args:
    124     images: 3-D or 4-D Tensor of images. Values are in [0, 255].
    125     height: Integer. Height of resized output image.
    126     width: Integer. Width of resized output image.
    127     scope: Optional scope for name_scope.
    129   Returns:
    130     3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1].
    131   """
    132   is_single = images.shape.ndims == 3
    133   with ops.name_scope(scope, 'preprocess', [images, height, width]):
    134     if not images.dtype.is_floating:
    135       images = math_ops.to_float(images)
    136     if is_single:
    137       images = array_ops.expand_dims(images, axis=0)
    138     resized = image_ops.resize_bilinear(images, [height, width])
    139     resized = (resized - 128.0) / 128.0
    140     if is_single:
    141       resized = array_ops.squeeze(resized, axis=0)
    142     return resized
    145 def _kl_divergence(p, p_logits, q):
    146   """Computes the Kullback-Liebler divergence between p and q.
    148   This function uses p's logits in some places to improve numerical stability.
    150   Specifically:
    152   KL(p || q) = sum[ p * log(p / q) ]
    153     = sum[ p * ( log(p)                - log(q) ) ]
    154     = sum[ p * ( log_softmax(p_logits) - log(q) ) ]
    156   Args:
    157     p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
    158       example and `j` corresponds to the probability of being in class `j`.
    159     p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
    160     q: A 1-D floating-point Tensor, where q_j corresponds to the probability
    161       of class `j`.
    163   Returns:
    164     KL divergence between two distributions. Output dimension is 1D, one entry
    165     per distribution in `p`.
    167   Raises:
    168     ValueError: If any of the inputs aren't floating-point.
    169     ValueError: If p or p_logits aren't 2D.
    170     ValueError: If q isn't 1D.
    171   """
    172   for tensor in [p, p_logits, q]:
    173     if not tensor.dtype.is_floating:
    174       raise ValueError('Input %s must be floating type.', tensor.name)
    175   p.shape.assert_has_rank(2)
    176   p_logits.shape.assert_has_rank(2)
    177   q.shape.assert_has_rank(1)
    178   return math_ops.reduce_sum(
    179       p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)
    182 def get_graph_def_from_disk(filename):
    183   """Get a GraphDef proto from a disk location."""
    184   with gfile.FastGFile(filename, 'rb') as f:
    185     return graph_pb2.GraphDef.FromString(f.read())
    188 def get_graph_def_from_resource(filename):
    189   """Get a GraphDef proto from within a .par file."""
    190   return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename))
    193 def get_graph_def_from_url_tarball(url, filename, tar_filename=None):
    194   """Get a GraphDef proto from a tarball on the web.
    196   Args:
    197     url: Web address of tarball
    198     filename: Filename of graph definition within tarball
    199     tar_filename: Temporary download filename (None = always download)
    201   Returns:
    202     A GraphDef loaded from a file in the downloaded tarball.
    203   """
    204   if not (tar_filename and os.path.exists(tar_filename)):
    206     def _progress(count, block_size, total_size):
    207       sys.stdout.write('\r>> Downloading %s %.1f%%' %
    208                        (url,
    209                         float(count * block_size) / float(total_size) * 100.0))
    210       sys.stdout.flush()
    212     tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress)
    213   with tarfile.open(tar_filename, 'r:gz') as tar:
    214     proto_str = tar.extractfile(filename).read()
    215   return graph_pb2.GraphDef.FromString(proto_str)
    218 def _default_graph_def_fn():
    219   return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH,
    220                                         os.path.basename(INCEPTION_URL))
    223 def run_inception(images,
    224                   graph_def=None,
    225                   default_graph_def_fn=_default_graph_def_fn,
    226                   image_size=INCEPTION_DEFAULT_IMAGE_SIZE,
    227                   input_tensor=INCEPTION_INPUT,
    228                   output_tensor=INCEPTION_OUTPUT):
    229   """Run images through a pretrained Inception classifier.
    231   Args:
    232     images: Input tensors. Must be [batch, height, width, channels]. Input shape
    233       and values must be in [-1, 1], which can be achieved using
    234       `preprocess_image`.
    235     graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
    236       call `default_graph_def_fn` to get GraphDef.
    237     default_graph_def_fn: A function that returns a GraphDef. Used if
    238       `graph_def` is `None. By default, returns a pretrained InceptionV3 graph.
    239     image_size: Required image width and height. See unit tests for the default
    240       values.
    241     input_tensor: Name of input Tensor.
    242     output_tensor: Name or list of output Tensors. This function will compute
    243       activations at the specified layer. Examples include INCEPTION_V3_OUTPUT
    244       and INCEPTION_V3_FINAL_POOL which would result in this function computing
    245       the final logits or the penultimate pooling layer.
    247   Returns:
    248     Tensor or Tensors corresponding to computed `output_tensor`.
    250   Raises:
    251     ValueError: If images are not the correct size.
    252     ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided.
    253   """
    254   images = _validate_images(images, image_size)
    256   if graph_def is None:
    257     if default_graph_def_fn is None:
    258       raise ValueError('If `graph_def` is `None`, must provide '
    259                        '`default_graph_def_fn`.')
    260     graph_def = default_graph_def_fn()
    262   activations = run_image_classifier(images, graph_def, input_tensor,
    263                                      output_tensor)
    264   if isinstance(activations, list):
    265     for i, activation in enumerate(activations):
    266       if array_ops.rank(activation) != 2:
    267         activations[i] = layers.flatten(activation)
    268   else:
    269     if array_ops.rank(activations) != 2:
    270       activations = layers.flatten(activations)
    272   return activations
    275 def run_image_classifier(tensor, graph_def, input_tensor,
    276                          output_tensor, scope='RunClassifier'):
    277   """Runs a network from a frozen graph.
    279   Args:
    280     tensor: An Input tensor.
    281     graph_def: A GraphDef proto.
    282     input_tensor: Name of input tensor in graph def.
    283     output_tensor: A tensor name or list of tensor names in graph def.
    284     scope: Name scope for classifier.
    286   Returns:
    287     Classifier output if `output_tensor` is a string, or a list of outputs if
    288     `output_tensor` is a list.
    290   Raises:
    291     ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def.
    292   """
    293   input_map = {input_tensor: tensor}
    294   is_singleton = isinstance(output_tensor, str)
    295   if is_singleton:
    296     output_tensor = [output_tensor]
    297   classifier_outputs = importer.import_graph_def(
    298       graph_def, input_map, output_tensor, name=scope)
    299   if is_singleton:
    300     classifier_outputs = classifier_outputs[0]
    302   return classifier_outputs
    305 def classifier_score(images, classifier_fn, num_batches=1):
    306   """Classifier score for evaluating a conditional generative model.
    308   This is based on the Inception Score, but for an arbitrary classifier.
    310   This technique is described in detail in https://arxiv.org/abs/1606.03498. In
    311   summary, this function calculates
    313   exp( E[ KL(p(y|x) || p(y)) ] )
    315   which captures how different the network's classification prediction is from
    316   the prior distribution over classes.
    318   NOTE: This function consumes images, computes their logits, and then
    319   computes the classifier score. If you would like to precompute many logits for
    320   large batches, use clasifier_score_from_logits(), which this method also
    321   uses.
    323   Args:
    324     images: Images to calculate the classifier score for.
    325     classifier_fn: A function that takes images and produces logits based on a
    326       classifier.
    327     num_batches: Number of batches to split `generated_images` in to in order to
    328       efficiently run them through the classifier network.
    330   Returns:
    331     The classifier score. A floating-point scalar of the same type as the output
    332     of `classifier_fn`.
    333   """
    334   generated_images_list = array_ops.split(
    335       images, num_or_size_splits=num_batches)
    337   # Compute the classifier splits using the memory-efficient `map_fn`.
    338   logits = functional_ops.map_fn(
    339       fn=classifier_fn,
    340       elems=array_ops.stack(generated_images_list),
    341       parallel_iterations=1,
    342       back_prop=False,
    343       swap_memory=True,
    344       name='RunClassifier')
    345   logits = array_ops.concat(array_ops.unstack(logits), 0)
    347   return classifier_score_from_logits(logits)
    350 def classifier_score_from_logits(logits):
    351   """Classifier score for evaluating a generative model from logits.
    353   This method computes the classifier score for a set of logits. This can be
    354   used independently of the classifier_score() method, especially in the case
    355   of using large batches during evaluation where we would like precompute all
    356   of the logits before computing the classifier score.
    358   This technique is described in detail in https://arxiv.org/abs/1606.03498. In
    359   summary, this function calculates:
    361   exp( E[ KL(p(y|x) || p(y)) ] )
    363   which captures how different the network's classification prediction is from
    364   the prior distribution over classes.
    366   Args:
    367     logits: Precomputed 2D tensor of logits that will be used to
    368       compute the classifier score.
    370   Returns:
    371     The classifier score. A floating-point scalar of the same type as the output
    372     of `logits`.
    373   """
    374   logits.shape.assert_has_rank(2)
    376   # Use maximum precision for best results.
    377   logits_dtype = logits.dtype
    378   if logits_dtype != dtypes.float64:
    379     logits = math_ops.to_double(logits)
    381   p = nn_ops.softmax(logits)
    382   q = math_ops.reduce_mean(p, axis=0)
    383   kl = _kl_divergence(p, logits, q)
    384   kl.shape.assert_has_rank(1)
    385   log_score = math_ops.reduce_mean(kl)
    386   final_score = math_ops.exp(log_score)
    388   if logits_dtype != dtypes.float64:
    389     final_score = math_ops.cast(final_score, logits_dtype)
    391   return final_score
    394 inception_score = functools.partial(
    395     classifier_score,
    396     classifier_fn=functools.partial(
    397         run_inception, output_tensor=INCEPTION_OUTPUT))
    400 def trace_sqrt_product(sigma, sigma_v):
    401   """Find the trace of the positive sqrt of product of covariance matrices.
    403   '_symmetric_matrix_square_root' only works for symmetric matrices, so we
    404   cannot just take _symmetric_matrix_square_root(sigma * sigma_v).
    405   ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily).
    407   Let sigma = A A so A = sqrt(sigma), and sigma_v = B B.
    408   We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B))
    409   Note the following properties:
    410   (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1)
    411      => eigenvalues(A A B B) = eigenvalues (A B B A)
    412   (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2))
    413      => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A))
    414   (iii) forall M: trace(M) = sum(eigenvalues(M))
    415      => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
    416                                    = sum(sqrt(eigenvalues(A B B A)))
    417                                    = sum(eigenvalues(sqrt(A B B A)))
    418                                    = trace(sqrt(A B B A))
    419                                    = trace(sqrt(A sigma_v A))
    420   A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can**
    421   use the _symmetric_matrix_square_root function to find the roots of these
    422   matrices.
    424   Args:
    425     sigma: a square, symmetric, real, positive semi-definite covariance matrix
    426     sigma_v: same as sigma
    428   Returns:
    429     The trace of the positive square root of sigma*sigma_v
    430   """
    432   # Note sqrt_sigma is called "A" in the proof above
    433   sqrt_sigma = _symmetric_matrix_square_root(sigma)
    435   # This is sqrt(A sigma_v A) above
    436   sqrt_a_sigmav_a = math_ops.matmul(
    437       sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma))
    439   return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
    442 def frechet_classifier_distance(real_images,
    443                                 generated_images,
    444                                 classifier_fn,
    445                                 num_batches=1):
    446   """Classifier distance for evaluating a generative model.
    448   This is based on the Frechet Inception distance, but for an arbitrary
    449   classifier.
    451   This technique is described in detail in https://arxiv.org/abs/1706.08500.
    452   Given two Gaussian distribution with means m and m_w and covariance matrices
    453   C and C_w, this function calcuates
    455   |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
    457   which captures how different the distributions of real images and generated
    458   images (or more accurately, their visual features) are. Note that unlike the
    459   Inception score, this is a true distance and utilizes information about real
    460   world images.
    462   Note that when computed using sample means and sample covariance matrices,
    463   Frechet distance is biased. It is more biased for small sample sizes. (e.g.
    464   even if the two distributions are the same, for a small sample size, the
    465   expected Frechet distance is large). It is important to use the same
    466   sample size to compute frechet classifier distance when comparing two
    467   generative models.
    469   NOTE: This function consumes images, computes their activations, and then
    470   computes the classifier score. If you would like to precompute many
    471   activations for real and generated images for large batches, please use
    472   frechet_clasifier_distance_from_activations(), which this method also uses.
    474   Args:
    475     real_images: Real images to use to compute Frechet Inception distance.
    476     generated_images: Generated images to use to compute Frechet Inception
    477       distance.
    478     classifier_fn: A function that takes images and produces activations
    479       based on a classifier.
    480     num_batches: Number of batches to split images in to in order to
    481       efficiently run them through the classifier network.
    483   Returns:
    484     The Frechet Inception distance. A floating-point scalar of the same type
    485     as the output of `classifier_fn`.
    486   """
    488   real_images_list = array_ops.split(
    489       real_images, num_or_size_splits=num_batches)
    490   generated_images_list = array_ops.split(
    491       generated_images, num_or_size_splits=num_batches)
    493   imgs = array_ops.stack(real_images_list + generated_images_list)
    495   # Compute the activations using the memory-efficient `map_fn`.
    496   activations = functional_ops.map_fn(
    497       fn=classifier_fn,
    498       elems=imgs,
    499       parallel_iterations=1,
    500       back_prop=False,
    501       swap_memory=True,
    502       name='RunClassifier')
    504   # Split the activations by the real and generated images.
    505   real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
    507   # Ensure the activations have the right shapes.
    508   real_a = array_ops.concat(array_ops.unstack(real_a), 0)
    509   gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
    511   return frechet_classifier_distance_from_activations(real_a, gen_a)
    514 def frechet_classifier_distance_from_activations(
    515     real_activations, generated_activations):
    516   """Classifier distance for evaluating a generative model from activations.
    518   This methods computes the Frechet classifier distance from activations of
    519   real images and generated images. This can be used independently of the
    520   frechet_classifier_distance() method, especially in the case of using large
    521   batches during evaluation where we would like precompute all of the
    522   activations before computing the classifier distance.
    524   This technique is described in detail in https://arxiv.org/abs/1706.08500.
    525   Given two Gaussian distribution with means m and m_w and covariance matrices
    526   C and C_w, this function calcuates
    528   |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
    530   which captures how different the distributions of real images and generated
    531   images (or more accurately, their visual features) are. Note that unlike the
    532   Inception score, this is a true distance and utilizes information about real
    533   world images.
    535   Args:
    536     real_activations: 2D Tensor containing activations of real data. Shape is
    537       [batch_size, activation_size].
    538     generated_activations: 2D Tensor containing activations of generated data.
    539       Shape is [batch_size, activation_size].
    541   Returns:
    542    The Frechet Inception distance. A floating-point scalar of the same type
    543    as the output of the activations.
    545   """
    546   real_activations.shape.assert_has_rank(2)
    547   generated_activations.shape.assert_has_rank(2)
    549   activations_dtype = real_activations.dtype
    550   if activations_dtype != dtypes.float64:
    551     real_activations = math_ops.to_double(real_activations)
    552     generated_activations = math_ops.to_double(generated_activations)
    554   # Compute mean and covariance matrices of activations.
    555   m = math_ops.reduce_mean(real_activations, 0)
    556   m_v = math_ops.reduce_mean(generated_activations, 0)
    557   num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
    559   # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
    560   real_centered = real_activations - m
    561   sigma = math_ops.matmul(
    562       real_centered, real_centered, transpose_a=True) / (num_examples - 1)
    564   gen_centered = generated_activations - m_v
    565   sigma_v = math_ops.matmul(
    566       gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)
    568   # Find the Tr(sqrt(sigma sigma_v)) component of FID
    569   sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)
    571   # Compute the two components of FID.
    573   # First the covariance component.
    574   # Here, note that trace(A + B) = trace(A) + trace(B)
    575   trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component
    577   # Next the distance between means.
    578   mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    579   fid = trace + mean
    580   if activations_dtype != dtypes.float64:
    581     fid = math_ops.cast(fid, activations_dtype)
    583   return fid
    586 frechet_inception_distance = functools.partial(
    587     frechet_classifier_distance,
    588     classifier_fn=functools.partial(
    589         run_inception, output_tensor=INCEPTION_FINAL_POOL))