# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Model evaluation tools for TFGAN.

These methods come from https://arxiv.org/abs/1606.03498 and
https://arxiv.org/abs/1706.08500.

NOTE: This implementation uses the same weights as in
https://github.com/openai/improved-gan/blob/master/inception_score/model.py,
but is more numerically stable and is an unbiased estimator of the true
Inception score even when splitting the inputs into batches.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import functools
import os
import sys
import tarfile

import six
from six.moves import urllib

from tensorflow.contrib.layers.python.layers import layers
from tensorflow.core.framework import graph_pb2
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import importer
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import functional_ops
from tensorflow.python.ops import image_ops
from tensorflow.python.ops import linalg_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.platform import gfile
from tensorflow.python.platform import resource_loader


__all__ = [
    'get_graph_def_from_disk',
    'get_graph_def_from_resource',
    'get_graph_def_from_url_tarball',
    'preprocess_image',
    'run_image_classifier',
    'run_inception',
    'inception_score',
    'classifier_score',
    'classifier_score_from_logits',
    'frechet_inception_distance',
    'frechet_classifier_distance',
    'frechet_classifier_distance_from_activations',
    'INCEPTION_DEFAULT_IMAGE_SIZE',
]


INCEPTION_URL = 'http://download.tensorflow.org/models/frozen_inception_v1_2015_12_05.tar.gz'
INCEPTION_FROZEN_GRAPH = 'inceptionv1_for_inception_score.pb'
INCEPTION_INPUT = 'Mul:0'
INCEPTION_OUTPUT = 'logits:0'
INCEPTION_FINAL_POOL = 'pool_3:0'
INCEPTION_DEFAULT_IMAGE_SIZE = 299


def _validate_images(images, image_size):
  """Converts `images` to a Tensor and checks its static shape.

  Args:
    images: Tensor-like batch of images.
    image_size: Required height and width of each image.

  Returns:
    `images` as a Tensor.

  Raises:
    ValueError: If the static shape is not compatible with
      [batch, image_size, image_size, channels].
  """
  images = ops.convert_to_tensor(images)
  # `with_rank` is called only for its validation side effect.
  images.shape.with_rank(4)
  images.shape.assert_is_compatible_with(
      [None, image_size, image_size, None])
  return images


def _symmetric_matrix_square_root(mat, eps=1e-10):
  """Compute square root of a symmetric matrix.

  Note that this is different from an elementwise square root. We want to
  compute M' where M' = sqrt(mat) such that M' * M' = mat.

  Also note that this method **only** works for symmetric matrices.

  Args:
    mat: Matrix to take the square root of.
    eps: Small epsilon such that any element less than eps will not be square
      rooted to guard against numerical instability.

  Returns:
    Matrix square root of mat.
  """
  # Unlike numpy, tensorflow's return order is (s, u, v)
  s, u, v = linalg_ops.svd(mat)
  # sqrt is unstable around 0, so singular values below `eps` are passed
  # through unchanged instead of being square rooted.
  si = array_ops.where(math_ops.less(s, eps), s, math_ops.sqrt(s))
  # Note that the v returned by Tensorflow is v = V
  # (when referencing the equation A = U S V^T)
  # This is unlike Numpy which returns v = V^T
  return math_ops.matmul(
      math_ops.matmul(u, array_ops.diag(si)), v, transpose_b=True)


def preprocess_image(
    images, height=INCEPTION_DEFAULT_IMAGE_SIZE,
    width=INCEPTION_DEFAULT_IMAGE_SIZE, scope=None):
  """Prepare a batch of images for evaluation.

  This is the preprocessing portion of the graph from
  http://download.tensorflow.org/models/image/imagenet/inception-2015-12-05.tgz.

  Note that it expects Tensors in [0, 255]. This function maps pixel values to
  [-1, 1] and resizes to match the InceptionV1 network.

  Args:
    images: 3-D or 4-D Tensor of images. Values are in [0, 255].
    height: Integer. Height of resized output image.
    width: Integer. Width of resized output image.
    scope: Optional scope for name_scope.

  Returns:
    3-D or 4-D float Tensor of prepared image(s). Values are in [-1, 1].
  """
  is_single = images.shape.ndims == 3
  with ops.name_scope(scope, 'preprocess', [images, height, width]):
    if not images.dtype.is_floating:
      images = math_ops.to_float(images)
    if is_single:
      # `resize_bilinear` is given a batch, so add a leading batch dimension
      # for a single image and strip it again below.
      images = array_ops.expand_dims(images, axis=0)
    resized = image_ops.resize_bilinear(images, [height, width])
    resized = (resized - 128.0) / 128.0
    if is_single:
      resized = array_ops.squeeze(resized, axis=0)
    return resized


def _kl_divergence(p, p_logits, q):
  """Computes the Kullback-Leibler divergence between p and q.

  This function uses p's logits in some places to improve numerical stability.

  Specifically:

  KL(p || q) = sum[ p * log(p / q) ]
    = sum[ p * ( log(p) - log(q) ) ]
    = sum[ p * ( log_softmax(p_logits) - log(q) ) ]

  Args:
    p: A 2-D floating-point Tensor p_ij, where `i` corresponds to the minibatch
      example and `j` corresponds to the probability of being in class `j`.
    p_logits: A 2-D floating-point Tensor corresponding to logits for `p`.
    q: A 1-D floating-point Tensor, where q_j corresponds to the probability
      of class `j`.

  Returns:
    KL divergence between two distributions. Output dimension is 1D, one entry
    per distribution in `p`.

  Raises:
    ValueError: If any of the inputs aren't floating-point.
    ValueError: If p or p_logits aren't 2D.
    ValueError: If q isn't 1D.
  """
  for tensor in [p, p_logits, q]:
    if not tensor.dtype.is_floating:
      # Interpolate the name into the message; passing it as a second
      # argument to ValueError would leave the `%s` placeholder unformatted.
      raise ValueError('Input %s must be floating type.' % tensor.name)
  p.shape.assert_has_rank(2)
  p_logits.shape.assert_has_rank(2)
  q.shape.assert_has_rank(1)
  return math_ops.reduce_sum(
      p * (nn_ops.log_softmax(p_logits) - math_ops.log(q)), axis=1)


def get_graph_def_from_disk(filename):
  """Get a GraphDef proto from a disk location.

  Args:
    filename: Path of a file containing a serialized GraphDef.

  Returns:
    The GraphDef parsed from the file's contents.
  """
  with gfile.FastGFile(filename, 'rb') as f:
    return graph_pb2.GraphDef.FromString(f.read())


def get_graph_def_from_resource(filename):
  """Get a GraphDef proto from within a .par file.

  Args:
    filename: Resource path passed to `resource_loader.load_resource`.

  Returns:
    The GraphDef parsed from the resource's contents.
  """
  return graph_pb2.GraphDef.FromString(resource_loader.load_resource(filename))


def get_graph_def_from_url_tarball(url, filename, tar_filename=None):
  """Get a GraphDef proto from a tarball on the web.

  The tarball is downloaded unless `tar_filename` is given and already exists
  on disk, in which case the existing file is reused.

  Args:
    url: Web address of tarball
    filename: Filename of graph definition within tarball
    tar_filename: Temporary download filename (None = always download)

  Returns:
    A GraphDef loaded from a file in the downloaded tarball.

  Raises:
    KeyError: If `filename` is not a member of the tarball.
    ValueError: If `filename` is a member of the tarball but is not a regular
      file, so its contents cannot be read.
  """
  if not (tar_filename and os.path.exists(tar_filename)):

    def _progress(count, block_size, total_size):
      # `urlretrieve` reports a non-positive total size when the server does
      # not send one; skip the percentage rather than dividing by it.
      if total_size > 0:
        percent = float(count * block_size) / float(total_size) * 100.0
        sys.stdout.write('\r>> Downloading %s %.1f%%' % (url, percent))
        sys.stdout.flush()

    tar_filename, _ = urllib.request.urlretrieve(url, tar_filename, _progress)
  with tarfile.open(tar_filename, 'r:gz') as tar:
    member = tar.extractfile(filename)
    if member is None:
      # `extractfile` returns None for members that are not regular files.
      raise ValueError(
          '`%s` in `%s` is not a regular file.' % (filename, tar_filename))
    proto_str = member.read()
  return graph_pb2.GraphDef.FromString(proto_str)


def _default_graph_def_fn():
  """Returns the GraphDef for `INCEPTION_FROZEN_GRAPH` from `INCEPTION_URL`.

  The tarball is stored under the basename of `INCEPTION_URL` in the current
  working directory and reused if it is already there.
  """
  return get_graph_def_from_url_tarball(INCEPTION_URL, INCEPTION_FROZEN_GRAPH,
                                        os.path.basename(INCEPTION_URL))


def run_inception(images,
                  graph_def=None,
                  default_graph_def_fn=_default_graph_def_fn,
                  image_size=INCEPTION_DEFAULT_IMAGE_SIZE,
                  input_tensor=INCEPTION_INPUT,
                  output_tensor=INCEPTION_OUTPUT):
  """Run images through a pretrained Inception classifier.

  Args:
    images: Input tensors. Must be [batch, height, width, channels]. Input shape
      and values must be in [-1, 1], which can be achieved using
      `preprocess_image`.
    graph_def: A GraphDef proto of a pretrained Inception graph. If `None`,
      call `default_graph_def_fn` to get GraphDef.
    default_graph_def_fn: A function that returns a GraphDef. Used if
      `graph_def` is `None`. By default, downloads the frozen graph
      `INCEPTION_FROZEN_GRAPH` from `INCEPTION_URL`.
    image_size: Required image width and height. See unit tests for the default
      values.
    input_tensor: Name of input Tensor.
    output_tensor: Name or list of output Tensors. This function will compute
      activations at the specified layer. Examples include `INCEPTION_OUTPUT`
      and `INCEPTION_FINAL_POOL` which would result in this function computing
      the final logits or the penultimate pooling layer.

  Returns:
    Tensor or Tensors corresponding to computed `output_tensor`.

  Raises:
    ValueError: If images are not the correct size.
    ValueError: If neither `graph_def` nor `default_graph_def_fn` are provided.
  """
  images = _validate_images(images, image_size)

  if graph_def is None:
    if default_graph_def_fn is None:
      raise ValueError('If `graph_def` is `None`, must provide '
                       '`default_graph_def_fn`.')
    graph_def = default_graph_def_fn()

  activations = run_image_classifier(images, graph_def, input_tensor,
                                     output_tensor)
  # NOTE(review): `array_ops.rank` returns a Tensor rather than a Python int,
  # so these comparisons are presumably not static rank checks and `flatten`
  # may run for every output -- confirm before relying on the guard.
  if isinstance(activations, list):
    for i, activation in enumerate(activations):
      if array_ops.rank(activation) != 2:
        activations[i] = layers.flatten(activation)
  else:
    if array_ops.rank(activations) != 2:
      activations = layers.flatten(activations)

  return activations


def run_image_classifier(tensor, graph_def, input_tensor,
                         output_tensor, scope='RunClassifier'):
  """Runs a network from a frozen graph.

  Args:
    tensor: An Input tensor.
    graph_def: A GraphDef proto.
    input_tensor: Name of input tensor in graph def.
    output_tensor: A tensor name or list of tensor names in graph def.
    scope: Name scope for classifier.

  Returns:
    Classifier output if `output_tensor` is a string, or a list of outputs if
    `output_tensor` is a list.

  Raises:
    ValueError: If `input_tensor` or `output_tensor` aren't in the graph_def.
  """
  input_map = {input_tensor: tensor}
  # `six.string_types` also matches `unicode` names on Python 2, which a bare
  # `str` check would wrongly treat as a sequence of names.
  is_singleton = isinstance(output_tensor, six.string_types)
  if is_singleton:
    output_tensor = [output_tensor]
  classifier_outputs = importer.import_graph_def(
      graph_def, input_map, output_tensor, name=scope)
  if is_singleton:
    classifier_outputs = classifier_outputs[0]

  return classifier_outputs


def classifier_score(images, classifier_fn, num_batches=1):
  """Classifier score for evaluating a conditional generative model.

  This is based on the Inception Score, but for an arbitrary classifier.

  This technique is described in detail in https://arxiv.org/abs/1606.03498. In
  summary, this function calculates

  exp( E[ KL(p(y|x) || p(y)) ] )

  which captures how different the network's classification prediction is from
  the prior distribution over classes.

  NOTE: This function consumes images, computes their logits, and then
  computes the classifier score. If you would like to precompute many logits for
  large batches, use classifier_score_from_logits(), which this method also
  uses.

  Args:
    images: Images to calculate the classifier score for.
    classifier_fn: A function that takes images and produces logits based on a
      classifier.
    num_batches: Number of batches to split `images` in to in order to
      efficiently run them through the classifier network.

  Returns:
    The classifier score. A floating-point scalar of the same type as the output
    of `classifier_fn`.
  """
  generated_images_list = array_ops.split(
      images, num_or_size_splits=num_batches)

  # Compute the classifier splits using the memory-efficient `map_fn`.
  logits = functional_ops.map_fn(
      fn=classifier_fn,
      elems=array_ops.stack(generated_images_list),
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')
  # Merge the per-batch logits back into a single 2-D tensor.
  logits = array_ops.concat(array_ops.unstack(logits), 0)

  return classifier_score_from_logits(logits)


def classifier_score_from_logits(logits):
  """Classifier score for evaluating a generative model from logits.

  This method computes the classifier score for a set of logits. This can be
  used independently of the classifier_score() method, especially in the case
  of using large batches during evaluation where we would like to precompute
  all of the logits before computing the classifier score.

  This technique is described in detail in https://arxiv.org/abs/1606.03498. In
  summary, this function calculates:

  exp( E[ KL(p(y|x) || p(y)) ] )

  which captures how different the network's classification prediction is from
  the prior distribution over classes.

  Args:
    logits: Precomputed 2D tensor of logits that will be used to
      compute the classifier score.

  Returns:
    The classifier score. A floating-point scalar of the same type as the output
    of `logits`.
  """
  logits.shape.assert_has_rank(2)

  # Use maximum precision for best results.
  logits_dtype = logits.dtype
  if logits_dtype != dtypes.float64:
    logits = math_ops.to_double(logits)

  p = nn_ops.softmax(logits)
  # q is the marginal class distribution, estimated as the batch mean of p.
  q = math_ops.reduce_mean(p, axis=0)
  kl = _kl_divergence(p, logits, q)
  kl.shape.assert_has_rank(1)
  log_score = math_ops.reduce_mean(kl)
  final_score = math_ops.exp(log_score)

  if logits_dtype != dtypes.float64:
    final_score = math_ops.cast(final_score, logits_dtype)

  return final_score


inception_score = functools.partial(
    classifier_score,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_OUTPUT))


def trace_sqrt_product(sigma, sigma_v):
  """Find the trace of the positive sqrt of product of covariance matrices.

  '_symmetric_matrix_square_root' only works for symmetric matrices, so we
  cannot just take _symmetric_matrix_square_root(sigma * sigma_v).
  ('sigma' and 'sigma_v' are symmetric, but their product is not necessarily).

  Let sigma = A A so A = sqrt(sigma), and sigma_v = B B.
  We want to find trace(sqrt(sigma sigma_v)) = trace(sqrt(A A B B))
  Note the following properties:
  (i) forall M1, M2: eigenvalues(M1 M2) = eigenvalues(M2 M1)
     => eigenvalues(A A B B) = eigenvalues (A B B A)
  (ii) if M1 = sqrt(M2), then eigenvalues(M1) = sqrt(eigenvalues(M2))
     => eigenvalues(sqrt(sigma sigma_v)) = sqrt(eigenvalues(A B B A))
  (iii) forall M: trace(M) = sum(eigenvalues(M))
     => trace(sqrt(sigma sigma_v)) = sum(eigenvalues(sqrt(sigma sigma_v)))
                                   = sum(sqrt(eigenvalues(A B B A)))
                                   = sum(eigenvalues(sqrt(A B B A)))
                                   = trace(sqrt(A B B A))
                                   = trace(sqrt(A sigma_v A))
  A = sqrt(sigma). Both sigma and A sigma_v A are symmetric, so we **can**
  use the _symmetric_matrix_square_root function to find the roots of these
  matrices.

  Args:
    sigma: a square, symmetric, real, positive semi-definite covariance matrix
    sigma_v: same as sigma

  Returns:
    The trace of the positive square root of sigma*sigma_v
  """

  # Note sqrt_sigma is called "A" in the proof above
  sqrt_sigma = _symmetric_matrix_square_root(sigma)

  # This is "A sigma_v A" above; its square root is taken on the next line.
  a_sigmav_a = math_ops.matmul(
      sqrt_sigma, math_ops.matmul(sigma_v, sqrt_sigma))

  return math_ops.trace(_symmetric_matrix_square_root(a_sigmav_a))


def frechet_classifier_distance(real_images,
                                generated_images,
                                classifier_fn,
                                num_batches=1):
  """Classifier distance for evaluating a generative model.

  This is based on the Frechet Inception distance, but for an arbitrary
  classifier.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance matrices
  C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Note that when computed using sample means and sample covariance matrices,
  Frechet distance is biased. It is more biased for small sample sizes. (e.g.
  even if the two distributions are the same, for a small sample size, the
  expected Frechet distance is large). It is important to use the same
  sample size to compute frechet classifier distance when comparing two
  generative models.

  NOTE: This function consumes images, computes their activations, and then
  computes the classifier score. If you would like to precompute many
  activations for real and generated images for large batches, please use
  frechet_classifier_distance_from_activations(), which this method also uses.

  Args:
    real_images: Real images to use to compute Frechet Inception distance.
    generated_images: Generated images to use to compute Frechet Inception
      distance.
    classifier_fn: A function that takes images and produces activations
      based on a classifier.
    num_batches: Number of batches to split images in to in order to
      efficiently run them through the classifier network.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as the output of `classifier_fn`.
  """

  real_images_list = array_ops.split(
      real_images, num_or_size_splits=num_batches)
  generated_images_list = array_ops.split(
      generated_images, num_or_size_splits=num_batches)

  # Real batches come first and generated batches second, so both sets go
  # through a single `map_fn` call.
  imgs = array_ops.stack(real_images_list + generated_images_list)

  # Compute the activations using the memory-efficient `map_fn`.
  activations = functional_ops.map_fn(
      fn=classifier_fn,
      elems=imgs,
      parallel_iterations=1,
      back_prop=False,
      swap_memory=True,
      name='RunClassifier')

  # Split the activations by the real and generated images.
  real_a, gen_a = array_ops.split(activations, [num_batches, num_batches], 0)

  # Ensure the activations have the right shapes.
  real_a = array_ops.concat(array_ops.unstack(real_a), 0)
  gen_a = array_ops.concat(array_ops.unstack(gen_a), 0)

  return frechet_classifier_distance_from_activations(real_a, gen_a)


def frechet_classifier_distance_from_activations(
    real_activations, generated_activations):
  """Classifier distance for evaluating a generative model from activations.

  This method computes the Frechet classifier distance from activations of
  real images and generated images. This can be used independently of the
  frechet_classifier_distance() method, especially in the case of using large
  batches during evaluation where we would like to precompute all of the
  activations before computing the classifier distance.

  This technique is described in detail in https://arxiv.org/abs/1706.08500.
  Given two Gaussian distributions with means m and m_w and covariance matrices
  C and C_w, this function calculates

  |m - m_w|^2 + Tr(C + C_w - 2(C * C_w)^(1/2))

  which captures how different the distributions of real images and generated
  images (or more accurately, their visual features) are. Note that unlike the
  Inception score, this is a true distance and utilizes information about real
  world images.

  Args:
    real_activations: 2D Tensor containing activations of real data. Shape is
      [batch_size, activation_size].
    generated_activations: 2D Tensor containing activations of generated data.
      Shape is [batch_size, activation_size]. The batch size may differ from
      that of `real_activations`.

  Returns:
    The Frechet Inception distance. A floating-point scalar of the same type
    as `real_activations`.

  """
  real_activations.shape.assert_has_rank(2)
  generated_activations.shape.assert_has_rank(2)

  # Use maximum precision for the computation. Each input is cast on its own
  # so that mixed float32/float64 inputs still end up with matching dtypes.
  activations_dtype = real_activations.dtype
  if activations_dtype != dtypes.float64:
    real_activations = math_ops.to_double(real_activations)
  if generated_activations.dtype != dtypes.float64:
    generated_activations = math_ops.to_double(generated_activations)

  # Compute mean and covariance matrices of activations.
  m = math_ops.reduce_mean(real_activations, 0)
  m_v = math_ops.reduce_mean(generated_activations, 0)
  # Each covariance is normalized by its own sample count; the two inputs are
  # not required to contain the same number of examples.
  num_examples_real = math_ops.to_double(array_ops.shape(real_activations)[0])
  num_examples_generated = math_ops.to_double(
      array_ops.shape(generated_activations)[0])

  # sigma = (1 / (n - 1)) * (X - mu)^T (X - mu)
  real_centered = real_activations - m
  sigma = math_ops.matmul(
      real_centered, real_centered, transpose_a=True) / (
          num_examples_real - 1)

  gen_centered = generated_activations - m_v
  sigma_v = math_ops.matmul(
      gen_centered, gen_centered, transpose_a=True) / (
          num_examples_generated - 1)

  # Find the Tr(sqrt(sigma sigma_v)) component of FID
  sqrt_trace_component = trace_sqrt_product(sigma, sigma_v)

  # Compute the two components of FID.

  # First the covariance component.
  # Here, note that trace(A + B) = trace(A) + trace(B)
  trace = math_ops.trace(sigma + sigma_v) - 2.0 * sqrt_trace_component

  # Next the distance between means.
  mean = math_ops.square(linalg_ops.norm(m - m_v))  # This uses the L2 norm.
  fid = trace + mean
  if activations_dtype != dtypes.float64:
    fid = math_ops.cast(fid, activations_dtype)

  return fid


frechet_inception_distance = functools.partial(
    frechet_classifier_distance,
    classifier_fn=functools.partial(
        run_inception, output_tensor=INCEPTION_FINAL_POOL))