Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Model evaluation tools for TFGAN.
     16 
     17 These methods come from https://arxiv.org/abs/1606.03498 and
     18 https://arxiv.org/abs/1706.08500.
     19 
     20 NOTE: This implementation uses the same weights as in
     21 https://github.com/openai/improved-gan/blob/master/inception_score/model.py,
     22 but is more numerically stable and is an unbiased estimator of the true
     23 Inception score even when splitting the inputs into batches.
     24 """
     25 
     26 from __future__ import absolute_import
     27 from __future__ import division
     28 from __future__ import print_function
     29 
     30 import functools
     31 import os
     32 import sys
     33 import tarfile
     34 
     35 from six.moves import urllib
     36 
     37 from tensorflow.contrib.layers.python.layers import layers
     38 from tensorflow.core.framework import graph_pb2
     39 from tensorflow.python.framework import dtypes
     40 from tensorflow.python.framework import importer
     41 from tensorflow.python.framework import ops
     42 from tensorflow.python.ops import array_ops
     43 from tensorflow.python.ops import functional_ops
     44 from tensorflow.python.ops import image_ops
     45 from tensorflow.python.ops import linalg_ops
     46 from tensorflow.python.ops import math_ops
     47 from tensorflow.python.ops import nn_ops
     48 from tensorflow.python.platform import gfile
     49 from tensorflow.python.platform import resource_loader
     50 
     51 
# Names exported by a star-import of this module.
# NOTE(review): `trace_sqrt_product` is defined below without a leading
# underscore but is not listed here -- confirm whether that is intentional.
__all__ = [
    'get_graph_def_from_disk',
    'get_graph_def_from_resource',
    'get_graph_def_from_url_tarball',
    'preprocess_image',
    'run_image_classifier',
    'run_inception',
    'inception_score',
    'classifier_score',
    'classifier_score_from_logits',
    'frechet_inception_distance',
    'frechet_classifier_distance',
    'frechet_classifier_distance_from_activations',
    'INCEPTION_DEFAULT_IMAGE_SIZE',
]
     67 
     68 
# Tarball fetched by `_default_graph_def_fn` to obtain the default frozen graph.
INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
# Name of the GraphDef file read out of that tarball.
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
# Tensor names inside the frozen graph: the default input and output of
# `run_inception`, and the layer used by `frechet_inception_distance`.
INCEPTION_INPUT = 'Mul:0'
INCEPTION_OUTPUT = 'logits:0'
INCEPTION_FINAL_POOL = 'pool_3:0'
# Default height and width used by `preprocess_image` and `run_inception`.
INCEPTION_DEFAULT_IMAGE_SIZE = 299
     75 
     76 
     77 def _validate_images(images, image_size):
     78   images = ops.convert_to_tensor(images)
     79   images.shape.with_rank(4)
     80   images.shape.assert_is_compatible_with(
     81       [None, image_size, image_size, None])
     82   return images
     83 
     84 
     85 def _symmetric_matrix_square_root(mat, eps=1e-10):
     86   """Compute square root of a symmetric matrix.
     87 
     88   Note that this is different from an elementwise square root. We want to
     89   compute M' where M' = sqrt(mat) such that M' * M' = mat.
     90 
     91   Also note that this method **only** works for symmetric matrices.
     92 
     93   Args:
     94     mat: Matrix to take the square root of.
     95     eps: Small epsilon such that any element less than eps will not be square
     96       rooted to guard against numerical instability.
     97 
     98   Returns:
     99     Matrix square root of mat.
    100   """
    101   # Unlike numpy, tensorflow's return order is (s, u, v)
    102   s, u, v = linalg_ops.svd(mat)
    103   # sqrt is unstable around 0, just use 0 in such case
    104   si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
    105   # Note that the v returned by Tensorflow is v = V
    106   # (when referencing the equation A = U S V^T)
    107   # This is unlike Numpy which returns v = V^T
    108   return math_ops.matmul(
    109       math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)
    110 
    111 
    112 def preprocess_image(
    113     images, height=INCEPTION_DEFAULT_IMAGE_SIZE,
    114     width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None):
    115   """Prepare a batch of images for evaluation.
    116 
    117   This is the preprocessing portion of the graph from
    118   http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.
    119 
    120   Note that it expects Tensors in [0, 255]. This function maps pixel values to
    121   [-1, 1] and resizes to match the InceptionV1 network.
    122 
    123   Args:
    124     images: 3-D or 4-D Tensor of images. Values are in [0, 255].
    125     height: Integer. Height of resized output image.
    126     width: Integer. Width of resized output image.
    127     scope: Optional scope for name_scope.
    128 
    129   Returns:
    130     3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1].
    131   """
    132   is_single = images.shape.ndims == 3
    133   with ops.name_scope(scope, 'preprocess', [images, height, width]):
    134     if not images.dtype.is_floating:
    135       images = math_ops.to_float(images)
    136     if is_single:
    137       images = array_ops.expand_dims(images, axis=0)
    138     resized = image_ops.resize_bilinear(images, [height, width])
    139     resized = (resized - 128.0) / 128.0
    140     if is_single:
    141       resized = array_ops.squeeze(resized, axis=0)
    142     return resized
    143 
    144 
    145 def _kl_divergence(p, p_logits, q):
    146   """Computes the Kullback-Liebler divergence between p and q.
    147 
    148   This function uses p's logits in some places to improve numerical stability.
    149 
    150   Specifically:
    151 
    152   KL(p || q) = sum[ p * log(p / q) ]
    153     = sum[ p * ( log(p)                - log(q) ) ]
    154     = sum[ p * ( log_softmax(p_logits) - log(q) ) ]
    155 
    156   Args:
    157     p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
    158       example and `j` corresponds to the probability of being in class `j`.
    159     p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
    160     q: A 1-D floating-point Tensor, where q_j corresponds to the probability
    161       of class `j`.
    162 
    163   Returns:
    164     KL divergence between two distributions. Output dimension is 1D, one entry
    165     per distribution in `p`.
    166 
    167   Raises:
    168     ValueError: If any of the inputs aren't floating-point.
    169     ValueError: If p or p_logits aren't 2D.
    170     ValueError: If q isn't 1D.
    171   """
    172   for tensor in [p, p_logits, q]:
    173     if not tensor.dtype.is_floating:
    174       raise ValueError('Input %s must be floating type.', tensor.name)
    175   p.shape.assert_has_rank(2)
    176   p_logits.shape.assert_has_rank(2)
    177   q.shape.assert_has_rank(1)
    178   return math_ops.reduce_sum(
    179       p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)
    180 
    181 
    182 def get_graph_def_from_disk(filename):
    183   """Get a GraphDef proto from a disk location."""
    184   with gfile.FastGFile(filename, 'rb') as f:
    185     return graph_pb2.GraphDef.FromString(f.read())
    186 
    187 
    188 def get_graph_def_from_resource(filename):
    189   """Get a GraphDef proto from within a .par file."""
    190   return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename))
    191 
    192 
    193 def get_graph_def_from_url_tarball(url, filename, tar_filename=None):
    194   """Get a GraphDef proto from a tarball on the web.
    195 
    196   Args:
    197     url: Web address of tarball
    198     filename: Filename of graph definition within tarball
    199     tar_filename: Temporary download filename (None = always download)
    200 
    201   Returns:
    202     A GraphDef loaded from a file in the downloaded tarball.
    203   """
    204   if not (tar_filename and os.path.exists(tar_filename)):
    205 
    206     def _progress(count, block_size, total_size):
    207       sys.stdout.write('\r>> Downloading %s %.1f%%' %
    208                        (url,
    209                         float(count * block_size) / float(total_size) * 100.0))
    210       sys.stdout.flush()
    211 
    212     tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress)
    213   with tarfile.open(tar_filename, 'r:gz') as tar:
    214     proto_str = tar.extractfile(filename).read()
    215   return graph_pb2.GraphDef.FromString(proto_str)
    216 
    217 
    218 def _default_graph_def_fn():
    219   return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH,
    220                                         os.path.basename(INCEPTION_URL))
    221 
    222 
    223 def run_inception(images,
    224                   graph_def=None,
    225                   default_graph_def_fn=_default_graph_def_fn,
    226                   image_size=INCEPTION_DEFAULT_IMAGE_SIZE,
    227                   input_tensor=INCEPTION_INPUT,
    228                   output_tensor=INCEPTION_OUTPUT):
    229   """Run images through a pretrained Inception classifier.
    230 
    231   Args:
    232     images: Input tensors. Must be [batch, height, width, channels]. Input shape
    233       and values must be in [-1, 1], which can be achieved using
    234       `preprocess_image`.
    235     graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
    236       call `default_graph_def_fn` to get GraphDef.
    237     default_graph_def_fn: A function that returns a GraphDef. Used if
    238       `graph_def` is `None. By default, returns a pretrained InceptionV3 graph.
    239     image_size: Required image width and height. See unit tests for the default
    240       values.
    241     input_tensor: Name of input Tensor.
    242     output_tensor: Name or list of output Tensors. This function will compute
    243       activations at the specified layer. Examples include INCEPTION_V3_OUTPUT
    244       and INCEPTION_V3_FINAL_POOL which would result in this function computing
    245       the final logits or the penultimate pooling layer.
    246 
    247   Returns:
    248     Tensor or Tensors corresponding to computed `output_tensor`.
    249 
    250   Raises:
    251     ValueError: If images are not the correct size.
    252     ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided.
    253   """
    254   images = _validate_images(images, image_size)
    255 
    256   if graph_def is None:
    257     if default_graph_def_fn is None:
    258       raise ValueError('If `graph_def` is `None`, must provide '
    259                        '`default_graph_def_fn`.')
    260     graph_def = default_graph_def_fn()
    261 
    262   activations = run_image_classifier(images, graph_def, input_tensor,
    263                                      output_tensor)
    264   if isinstance(activations, list):
    265     for i, activation in enumerate(activations):
    266       if array_ops.rank(activation) != 2:
    267         activations[i] = layers.flatten(activation)
    268   else:
    269     if array_ops.rank(activations) != 2:
    270       activations = layers.flatten(activations)
    271 
    272   return activations
    273 
    274 
    275 def run_image_classifier(tensor, graph_def, input_tensor,
    276                          output_tensor, scope='RunClassifier'):
    277   """Runs a network from a frozen graph.
    278 
    279   Args:
    280     tensor: An Input tensor.
    281     graph_def: A GraphDef proto.
    282     input_tensor: Name of input tensor in graph def.
    283     output_tensor: A tensor name or list of tensor names in graph def.
    284     scope: Name scope for classifier.
    285 
    286   Returns:
    287     Classifier output if `output_tensor` is a string, or a list of outputs if
    288     `output_tensor` is a list.
    289 
    290   Raises:
    291     ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def.
    292   """
    293   input_map = {input_tensor: tensor}
    294   is_singleton = isinstance(output_tensor, str)
    295   if is_singleton:
    296     output_tensor = [output_tensor]
    297   classifier_outputs = importer.import_graph_def(
    298       graph_def, input_map, output_tensor, name=scope)
    299   if is_singleton:
    300     classifier_outputs = classifier_outputs[0]
    301 
    302   return classifier_outputs
    303 
    304 
    305 def classifier_score(images, classifier_fn, num_batches=1):
    306   """Classifier score for evaluating a conditional generative model.
    307 
    308   This is based on the Inception Score, but for an arbitrary classifier.
    309 
    310   This technique is described in detail in https://arxiv.org/abs/1606.03498. In
    311   summary, this function calculates
    312 
    313   exp( E[ KL(p(y|x) || p(y)) ] )
    314 
    315   which captures how different the network's classification prediction is from
    316   the prior distribution over classes.
    317 
    318   NOTE: This function consumes images, computes their logits, and then
    319   computes the classifier score. If you would like to precompute many logits for
    320   large batches, use clasifier_score_from_logits(), which this method also
    321   uses.
    322 
    323   Args:
    324     images: Images to calculate the classifier score for.
    325     classifier_fn: A function that takes images and produces logits based on a
    326       classifier.
    327     num_batches: Number of batches to split `generated_images` in to in order to
    328       efficiently run them through the classifier network.
    329 
    330   Returns:
    331     The classifier score. A floating-point scalar of the same type as the output
    332     of `classifier_fn`.
    333   """
    334   generated_images_list = array_ops.split(
    335       images, num_or_size_splits=num_batches)
    336 
    337   # Compute the classifier splits using the memory-efficient `map_fn`.
    338   logits = functional_ops.map_fn(
    339       fn=classifier_fn,
    340       elems=array_ops.stack(generated_images_list),
    341       parallel_iterations=1,
    342       back_prop=False,
    343       swap_memory=True,
    344       name='RunClassifier')
    345   logits = array_ops.concat(array_ops.unstack(logits), 0)
    346 
    347   return classifier_score_from_logits(logits)
    348 
    349 
    350 def classifier_score_from_logits(logits):
    351   """Classifier score for evaluating a generative model from logits.
    352 
    353   This method computes the classifier score for a set of logits. This can be
    354   used independently of the classifier_score() method, especially in the case
    355   of using large batches during evaluation where we would like precompute all
    356   of the logits before computing the classifier score.
    357 
    358   This technique is described in detail in https://arxiv.org/abs/1606.03498. In
    359   summary, this function calculates:
    360 
    361   exp( E[ KL(p(y|x) || p(y)) ] )
    362 
    363   which captures how different the network's classification prediction is from
    364   the prior distribution over classes.
    365 
    366   Args:
    367     logits: Precomputed 2D tensor of logits that will be used to
    368       compute the classifier score.
    369 
    370   Returns:
    371     The classifier score. A floating-point scalar of the same type as the output
    372     of `logits`.
    373   """
    374   logits.shape.assert_has_rank(2)
    375 
    376   # Use maximum precision for best results.
    377   logits_dtype = logits.dtype
    378   if logits_dtype != dtypes.float64:
    379     logits = math_ops.to_double(logits)
    380 
    381   p = nn_ops.softmax(logits)
    382   q = math_ops.reduce_mean(p, axis=0)
    383   kl = _kl_divergence(p, logits, q)
    384   kl.shape.assert_has_rank(1)
    385   log_score = math_ops.reduce_mean(kl)
    386   final_score = math_ops.exp(log_score)
    387 
    388   if logits_dtype != dtypes.float64:
    389     final_score = math_ops.cast(final_score, logits_dtype)
    390 
    391   return final_score
    392 
    393 
# `classifier_score` with the classifier fixed to `run_inception` reading the
# `INCEPTION_OUTPUT` tensor. Callers still pass `images` and may pass
# `num_batches`.
inception_score = functools.partial(
    classifier_score,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_OUTPUT))
    398 
    399 
    400 def trace_sqrt_product(sigma, sigma_v):
    401   """Find the trace of the positive sqrt of product of covariance matrices.
    402 
    403   '_symmetric_matrix_square_root' only works for symmetric matrices, so we
    404   cannot just take _symmetric_matrix_square_root(sigma * sigma_v).
    405   ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily).
    406 
    407   Let sigma = A A so A = sqrt(sigma), and sigma_v = B B.
    408   We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B))
    409   Note the following properties:
    410   (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1)
    411      => eigenvalues(A A B B) = eigenvalues (A B B A)
    412   (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2))
    413      => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A))
    414   (iii) forall M: trace(M) = sum(eigenvalues(M))
    415      => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
    416                                    = sum(sqrt(eigenvalues(A B B A)))
    417                                    = sum(eigenvalues(sqrt(A B B A)))
    418                                    = trace(sqrt(A B B A))
    419                                    = trace(sqrt(A sigma_v A))
    420   A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can**
    421   use the _symmetric_matrix_square_root function to find the roots of these
    422   matrices.
    423 
    424   Args:
    425     sigma: a square, symmetric, real, positive semi-definite covariance matrix
    426     sigma_v: same as sigma
    427 
    428   Returns:
    429     The trace of the positive square root of sigma*sigma_v
    430   """
    431 
    432   # Note sqrt_sigma is called "A" in the proof above
    433   sqrt_sigma = _symmetric_matrix_square_root(sigma)
    434 
    435   # This is sqrt(A sigma_v A) above
    436   sqrt_a_sigmav_a = math_ops.matmul(
    437       sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma))
    438 
    439   return math_ops.trace(_symmetric_matrix_square_root(sqrt_a_sigmav_a))
    440 
    441 
    442 def frechet_classifier_distance(real_images,
    443                                 generated_images,
    444                                 classifier_fn,
    445                                 num_batches=1):
    446   """Classifier distance for evaluating a generative model.
    447 
    448   This is based on the Frechet Inception distance, but for an arbitrary
    449   classifier.
    450 
    451   This technique is described in detail in https://arxiv.org/abs/1706.08500.
    452   Given two Gaussian distribution with means m and m_w and covariance matrices
    453   C and C_w, this function calcuates
    454 
    455   |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
    456 
    457   which captures how different the distributions of real images and generated
    458   images (or more accurately, their visual features) are. Note that unlike the
    459   Inception score, this is a true distance and utilizes information about real
    460   world images.
    461 
    462   Note that when computed using sample means and sample covariance matrices,
    463   Frechet distance is biased. It is more biased for small sample sizes. (e.g.
    464   even if the two distributions are the same, for a small sample size, the
    465   expected Frechet distance is large). It is important to use the same
    466   sample size to compute frechet classifier distance when comparing two
    467   generative models.
    468 
    469   NOTE: This function consumes images, computes their activations, and then
    470   computes the classifier score. If you would like to precompute many
    471   activations for real and generated images for large batches, please use
    472   frechet_clasifier_distance_from_activations(), which this method also uses.
    473 
    474   Args:
    475     real_images: Real images to use to compute Frechet Inception distance.
    476     generated_images: Generated images to use to compute Frechet Inception
    477       distance.
    478     classifier_fn: A function that takes images and produces activations
    479       based on a classifier.
    480     num_batches: Number of batches to split images in to in order to
    481       efficiently run them through the classifier network.
    482 
    483   Returns:
    484     The Frechet Inception distance. A floating-point scalar of the same type
    485     as the output of `classifier_fn`.
    486   """
    487 
    488   real_images_list = array_ops.split(
    489       real_images, num_or_size_splits=num_batches)
    490   generated_images_list = array_ops.split(
    491       generated_images, num_or_size_splits=num_batches)
    492 
    493   imgs = array_ops.stack(real_images_list + generated_images_list)
    494 
    495   # Compute the activations using the memory-efficient `map_fn`.
    496   activations = functional_ops.map_fn(
    497       fn=classifier_fn,
    498       elems=imgs,
    499       parallel_iterations=1,
    500       back_prop=False,
    501       swap_memory=True,
    502       name='RunClassifier')
    503 
    504   # Split the activations by the real and generated images.
    505   real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)
    506 
    507   # Ensure the activations have the right shapes.
    508   real_a = array_ops.concat(array_ops.unstack(real_a), 0)
    509   gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)
    510 
    511   return frechet_classifier_distance_from_activations(real_a, gen_a)
    512 
    513 
    514 def frechet_classifier_distance_from_activations(
    515     real_activations, generated_activations):
    516   """Classifier distance for evaluating a generative model from activations.
    517 
    518   This methods computes the Frechet classifier distance from activations of
    519   real images and generated images. This can be used independently of the
    520   frechet_classifier_distance() method, especially in the case of using large
    521   batches during evaluation where we would like precompute all of the
    522   activations before computing the classifier distance.
    523 
    524   This technique is described in detail in https://arxiv.org/abs/1706.08500.
    525   Given two Gaussian distribution with means m and m_w and covariance matrices
    526   C and C_w, this function calcuates
    527 
    528   |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))
    529 
    530   which captures how different the distributions of real images and generated
    531   images (or more accurately, their visual features) are. Note that unlike the
    532   Inception score, this is a true distance and utilizes information about real
    533   world images.
    534 
    535   Args:
    536     real_activations: 2D Tensor containing activations of real data. Shape is
    537       [batch_size, activation_size].
    538     generated_activations: 2D Tensor containing activations of generated data.
    539       Shape is [batch_size, activation_size].
    540 
    541   Returns:
    542    The Frechet Inception distance. A floating-point scalar of the same type
    543    as the output of the activations.
    544 
    545   """
    546   real_activations.shape.assert_has_rank(2)
    547   generated_activations.shape.assert_has_rank(2)
    548 
    549   activations_dtype = real_activations.dtype
    550   if activations_dtype != dtypes.float64:
    551     real_activations = math_ops.to_double(real_activations)
    552     generated_activations = math_ops.to_double(generated_activations)
    553 
    554   # Compute mean and covariance matrices of activations.
    555   m = math_ops.reduce_mean(real_activations, 0)
    556   m_v = math_ops.reduce_mean(generated_activations, 0)
    557   num_examples = math_ops.to_double(array_ops.shape(real_activations)[0])
    558 
    559   # sigma = (1 / (n - 1)) * (X - mu) (X - mu)^T
    560   real_centered = real_activations - m
    561   sigma = math_ops.matmul(
    562       real_centered, real_centered, transpose_a=True) / (num_examples - 1)
    563 
    564   gen_centered = generated_activations - m_v
    565   sigma_v = math_ops.matmul(
    566       gen_centered, gen_centered, transpose_a=True) / (num_examples - 1)
    567 
    568   # Find the Tr(sqrt(sigma sigma_v)) component of FID
    569   sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)
    570 
    571   # Compute the two components of FID.
    572 
    573   # First the covariance component.
    574   # Here, note that trace(A + B) = trace(A) + trace(B)
    575   trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component
    576 
    577   # Next the distance between means.
    578   mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
    579   fid = trace + mean
    580   if activations_dtype != dtypes.float64:
    581     fid = math_ops.cast(fid, activations_dtype)
    582 
    583   return fid
    584 
    585 
# `frechet_classifier_distance` with the classifier fixed to `run_inception`
# reading the `INCEPTION_FINAL_POOL` tensor. Callers still pass `real_images`
# and `generated_images`, and may pass `num_batches`.
frechet_inception_distance = functools.partial(
    frechet_classifier_distance,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_FINAL_POOL))
    590