# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
  """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES)` collections."""

  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      name=name)


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
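
  For example (an illustrative sketch; the shapes below are hypothetical):

  ```python
  predictions = tf.zeros([4, 3])   # rank 2
  weights = tf.ones([4, 3, 1])     # rank 3, trailing dimension of size 1
  predictions, _, weights = _remove_squeezable_dimensions(
      predictions, labels=None, weights=weights)
  # weights now has shape [4, 3]; a scalar weight would be left unchanged.
  ```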
75 """ 76 predictions = ops.convert_to_tensor(predictions) 77 if labels is not None: 78 labels, predictions = confusion_matrix.remove_squeezable_dimensions( 79 labels, predictions) 80 predictions.get_shape().assert_is_compatible_with(labels.get_shape()) 81 82 if weights is None: 83 return predictions, labels, None 84 85 weights = ops.convert_to_tensor(weights) 86 weights_shape = weights.get_shape() 87 weights_rank = weights_shape.ndims 88 if weights_rank == 0: 89 return predictions, labels, weights 90 91 predictions_shape = predictions.get_shape() 92 predictions_rank = predictions_shape.ndims 93 if (predictions_rank is not None) and (weights_rank is not None): 94 # Use static rank. 95 if weights_rank - predictions_rank == 1: 96 weights = array_ops.squeeze(weights, [-1]) 97 elif predictions_rank - weights_rank == 1: 98 weights = array_ops.expand_dims(weights, [-1]) 99 else: 100 # Use dynamic rank. 101 weights_rank_tensor = array_ops.rank(weights) 102 rank_diff = weights_rank_tensor - array_ops.rank(predictions) 103 104 def _maybe_expand_weights(): 105 return control_flow_ops.cond( 106 math_ops.equal(rank_diff, -1), 107 lambda: array_ops.expand_dims(weights, [-1]), lambda: weights) 108 109 # Don't attempt squeeze if it will fail based on static check. 110 if ((weights_rank is not None) and 111 (not weights_shape.dims[-1].is_compatible_with(1))): 112 maybe_squeeze_weights = lambda: weights 113 else: 114 maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1]) 115 116 def _maybe_adjust_weights(): 117 return control_flow_ops.cond( 118 math_ops.equal(rank_diff, 1), maybe_squeeze_weights, 119 _maybe_expand_weights) 120 121 # If weights are scalar, do nothing. Otherwise, try to add or remove a 122 # dimension to match predictions. 123 weights = control_flow_ops.cond( 124 math_ops.equal(weights_rank_tensor, 0), lambda: weights, 125 _maybe_adjust_weights) 126 return predictions, labels, weights 127 128 129 def _maybe_expand_labels(labels, predictions): 130 """If necessary, expand `labels` along last dimension to match `predictions`. 131 132 Args: 133 labels: `Tensor` or `SparseTensor` with shape 134 [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies 135 num_labels=1, in which case the result is an expanded `labels` with shape 136 [D1, ... DN, 1]. 137 predictions: `Tensor` with shape [D1, ... DN, num_classes]. 138 139 Returns: 140 `labels` with the same rank as `predictions`. 141 142 Raises: 143 ValueError: if `labels` has invalid shape. 144 """ 145 with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope: 146 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 147 148 # If sparse, expand sparse shape. 149 if isinstance(labels, sparse_tensor.SparseTensor): 150 return control_flow_ops.cond( 151 math_ops.equal( 152 array_ops.rank(predictions), 153 array_ops.size(labels.dense_shape) + 1), 154 lambda: sparse_ops.sparse_reshape( # pylint: disable=g-long-lambda 155 labels, 156 shape=array_ops.concat((labels.dense_shape, (1,)), 0), 157 name=scope), 158 lambda: labels) 159 160 # Otherwise, try to use static shape. 161 labels_rank = labels.get_shape().ndims 162 if labels_rank is not None: 163 predictions_rank = predictions.get_shape().ndims 164 if predictions_rank is not None: 165 if predictions_rank == labels_rank: 166 return labels 167 if predictions_rank == labels_rank + 1: 168 return array_ops.expand_dims(labels, -1, name=scope) 169 raise ValueError( 170 'Unexpected labels shape %s for predictions shape %s.' 
            % (labels.get_shape(), predictions.get_shape()))

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


def _safe_div(numerator, denominator, name):
  """Divides two tensors element-wise, returning 0 if the denominator is <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  t = math_ops.truediv(numerator, denominator)
  zero = array_ops.zeros_like(t, dtype=denominator.dtype)
  condition = math_ops.greater(denominator, zero)
  zero = math_ops.cast(zero, t.dtype)
  return array_ops.where(condition, t, zero, name=name)


def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return control_flow_ops.cond(
      math_ops.equal(
          array_ops.constant(0.0, dtype=dtypes.float64), denominator),
      lambda: array_ops.constant(0.0, dtype=dtypes.float64),
      lambda: math_ops.div(numerator, denominator),
      name=name)


def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.to_int64(predictions)
  labels = math_ops.to_int64(labels)
  num_classes = math_ops.to_int64(num_classes)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op


@tf_export('metrics.mean')
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
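
  A minimal usage sketch (graph mode, illustrative values; the local metric
  variables must be initialized before the first update):

  ```python
  values = tf.placeholder(tf.float32, shape=[None])
  mean_value, update_op = tf.metrics.mean(values)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op, feed_dict={values: [1., 2.]})
    sess.run(update_op, feed_dict={values: [3., 4.]})
    print(sess.run(mean_value))  # (1 + 2 + 3 + 4) / 4 = 2.5
  ```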
310 """ 311 if context.in_eager_mode(): 312 raise RuntimeError('tf.metrics.mean is not supported when eager execution ' 313 'is enabled.') 314 315 with variable_scope.variable_scope(name, 'mean', (values, weights)): 316 values = math_ops.to_float(values) 317 318 total = metric_variable([], dtypes.float32, name='total') 319 count = metric_variable([], dtypes.float32, name='count') 320 321 if weights is None: 322 num_values = math_ops.to_float(array_ops.size(values)) 323 else: 324 values, _, weights = _remove_squeezable_dimensions( 325 predictions=values, labels=None, weights=weights) 326 weights = weights_broadcast_ops.broadcast_weights( 327 math_ops.to_float(weights), values) 328 values = math_ops.multiply(values, weights) 329 num_values = math_ops.reduce_sum(weights) 330 331 update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values)) 332 with ops.control_dependencies([values]): 333 update_count_op = state_ops.assign_add(count, num_values) 334 335 mean_t = _safe_div(total, count, 'value') 336 update_op = _safe_div(update_total_op, update_count_op, 'update_op') 337 338 if metrics_collections: 339 ops.add_to_collections(metrics_collections, mean_t) 340 341 if updates_collections: 342 ops.add_to_collections(updates_collections, update_op) 343 344 return mean_t, update_op 345 346 347 @tf_export('metrics.accuracy') 348 def accuracy(labels, 349 predictions, 350 weights=None, 351 metrics_collections=None, 352 updates_collections=None, 353 name=None): 354 """Calculates how often `predictions` matches `labels`. 355 356 The `accuracy` function creates two local variables, `total` and 357 `count` that are used to compute the frequency with which `predictions` 358 matches `labels`. This frequency is ultimately returned as `accuracy`: an 359 idempotent operation that simply divides `total` by `count`. 360 361 For estimation of the metric over a stream of data, the function creates an 362 `update_op` operation that updates these variables and returns the `accuracy`. 363 Internally, an `is_correct` operation computes a `Tensor` with elements 1.0 364 where the corresponding elements of `predictions` and `labels` match and 0.0 365 otherwise. Then `update_op` increments `total` with the reduced sum of the 366 product of `weights` and `is_correct`, and it increments `count` with the 367 reduced sum of `weights`. 368 369 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 370 371 Args: 372 labels: The ground truth values, a `Tensor` whose shape matches 373 `predictions`. 374 predictions: The predicted values, a `Tensor` of any shape. 375 weights: Optional `Tensor` whose rank is either 0, or the same rank as 376 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 377 be either `1`, or the same as the corresponding `labels` dimension). 378 metrics_collections: An optional list of collections that `accuracy` should 379 be added to. 380 updates_collections: An optional list of collections that `update_op` should 381 be added to. 382 name: An optional variable_scope name. 383 384 Returns: 385 accuracy: A `Tensor` representing the accuracy, the value of `total` divided 386 by `count`. 387 update_op: An operation that increments the `total` and `count` variables 388 appropriately and whose value matches `accuracy`. 

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')


def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
      default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
      `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
      `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.to_float(predictions),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_pos))
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_neg))
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_neg))
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_pos))
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops


@tf_export('metrics.auc')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal'):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing lower or upper bound estimates of the AUC.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used, 'trapezoidal'
      [default] that applies the trapezoidal rule, 'minoring' that applies
      left summation for increasing intervals and right summation for decreasing
      intervals or 'majoring' that applies the opposite.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method == 'trapezoidal':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError('Invalid summation_method: %s' % summation_method)

    # sum up the areas of all the trapeziums
    auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
                            values['fp'], 'value')
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc_value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op


@tf_export('metrics.mean_absolute_error')
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')


@tf_export('metrics.mean_cosine_distance')
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, reduction_indices=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op


@tf_export('metrics.mean_per_class_accuracy')
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.
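
  A minimal usage sketch (graph mode, illustrative values):

  ```python
  labels = tf.constant([0, 0, 1, 1])
  predictions = tf.constant([0, 1, 1, 1])
  mean_acc, update_op = tf.metrics.mean_per_class_accuracy(
      labels, predictions, num_classes=2)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(mean_acc))  # class 0: 1/2, class 1: 2/2 -> mean 0.75
  ```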

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.to_int64(labels)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.to_float(math_ops.equal(predictions, labels))

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.to_float(weights)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    per_class_accuracy = _safe_div(count, total, None)

    mean_accuracy_v = math_ops.reduce_mean(
        per_class_accuracy, name='mean_accuracy')
    update_op = _safe_div(update_count_op, update_total_op, name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_accuracy_v)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op


@tf_export('metrics.mean_iou')
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
      denominator = sum_over_row + sum_over_col - cm_diag

      # The mean is only computed over classes that appear in the
      # label or prediction tensor. If the denominator is 0, we need to
      # ignore the class.
      num_valid_entries = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = array_ops.where(
          math_ops.greater(denominator, 0), denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)

      # If the number of valid entries is 0 (no classes) we return 0.
      result = array_ops.where(
          math_ops.greater(num_valid_entries, 0),
          math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
      return result

    mean_iou_v = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou_v)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op


@tf_export('metrics.mean_relative_error')
def mean_relative_error(labels,
                        predictions,
                        normalizer,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
  relative_errors = array_ops.where(
      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
      math_ops.div(math_ops.abs(labels - predictions), normalizer))
  return mean(relative_errors, weights, metrics_collections,
              updates_collections, name or 'mean_relative_error')


@tf_export('metrics.mean_squared_error')
def mean_squared_error(labels,
                       predictions,
                       weights=None,
                       metrics_collections=None,
                       updates_collections=None,
                       name=None):
  """Computes the mean squared error between the labels and predictions.

  The `mean_squared_error` function creates two local variables,
  `total` and `count` that are used to compute the mean squared error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_squared_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_squared_error`. Internally, a `squared_error` operation computes the
  element-wise square of the difference between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `squared_error`, and it increments `count` with the reduced sum
  of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_squared_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_squared_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_squared_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  squared_error = math_ops.square(labels - predictions)
  return mean(squared_error, weights, metrics_collections, updates_collections,
              name or 'mean_squared_error')


@tf_export('metrics.mean_tensor')
def mean_tensor(values,
                weights=None,
                metrics_collections=None,
                updates_collections=None,
                name=None):
  """Computes the element-wise (weighted) mean of the given tensors.

  In contrast to the `mean` function which returns a scalar with the
  mean, this function returns an average tensor with the same shape as the
  input tensors.

  The `mean_tensor` function creates two local variables,
  `total_tensor` and `count_tensor` that are used to compute the average of
  `values`. This average is ultimately returned as `mean` which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A float `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_value`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_tensor is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.to_float(values)
    total = metric_variable(
        values.get_shape(), dtypes.float32, name='total_tensor')
    count = metric_variable(
        values.get_shape(), dtypes.float32, name='count_tensor')

    num_values = array_ops.ones_like(values)
    if weights is not None:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.to_float(weights), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.multiply(num_values, weights)

    update_total_op = state_ops.assign_add(total, values)
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    def compute_mean(total, count, name):
      non_zero_count = math_ops.maximum(
          count, array_ops.ones_like(count), name=name)
      return math_ops.truediv(total, non_zero_count, name=name)

    mean_t = compute_mean(total, count, 'value')
    update_op = compute_mean(update_total_op, update_count_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_t)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op


@tf_export('metrics.percentage_below')
def percentage_below(values,
                     threshold,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None,
                     name=None):
  """Computes the percentage of values less than the given threshold.

  The `percentage_below` function creates two local variables,
  `total` and `count` that are used to compute the percentage of `values` that
  fall below `threshold`. This rate is weighted by `weights`, and it is
  ultimately returned as `percentage` which is an idempotent operation that
  simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `percentage`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A numeric `Tensor` of arbitrary size.
    threshold: A scalar threshold.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    percentage: A `Tensor` representing the current mean, the value of `total`
      divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.percentage_below is not supported when '
                       'eager execution is enabled.')

  is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
  return mean(is_below_threshold, weights, metrics_collections,
              updates_collections, name or 'percentage_below_threshold')


def _count_condition(values,
                     weights=None,
                     metrics_collections=None,
                     updates_collections=None):
  """Sums the weights of cases where the given values are True.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `bool` `Tensor` of arbitrary size.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.

  Returns:
    value_tensor: A `Tensor` representing the current value of the metric.
    update_op: An operation that accumulates the error from a batch of data.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
  """
  check_ops.assert_type(values, dtypes.bool)
  count = metric_variable([], dtypes.float32, name='count')

  values = math_ops.to_float(values)
  if weights is not None:
    with ops.control_dependencies((check_ops.assert_rank_in(
        weights, (0, array_ops.rank(values))),)):
      weights = math_ops.to_float(weights)
      values = math_ops.multiply(values, weights)

  value_tensor = array_ops.identity(count)
  update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))

  if metrics_collections:
    ops.add_to_collections(metrics_collections, value_tensor)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return value_tensor, update_op


@tf_export('metrics.false_negatives')
def false_negatives(labels,
                    predictions,
                    weights=None,
                    metrics_collections=None,
                    updates_collections=None,
                    name=None):
  """Computes the total number of false negatives.
1378 1379 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1380 1381 Args: 1382 labels: The ground truth values, a `Tensor` whose dimensions must match 1383 `predictions`. Will be cast to `bool`. 1384 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1385 be cast to `bool`. 1386 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1387 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1388 be either `1`, or the same as the corresponding `labels` dimension). 1389 metrics_collections: An optional list of collections that the metric 1390 value variable should be added to. 1391 updates_collections: An optional list of collections that the metric update 1392 ops should be added to. 1393 name: An optional variable_scope name. 1394 1395 Returns: 1396 value_tensor: A `Tensor` representing the current value of the metric. 1397 update_op: An operation that accumulates the error from a batch of data. 1398 1399 Raises: 1400 ValueError: If `weights` is not `None` and its shape doesn't match `values`, 1401 or if either `metrics_collections` or `updates_collections` are not a list 1402 or tuple. 1403 RuntimeError: If eager execution is enabled. 1404 """ 1405 if context.in_eager_mode(): 1406 raise RuntimeError('tf.metrics.false_negatives is not supported when ' 1407 'eager execution is enabled.') 1408 1409 with variable_scope.variable_scope(name, 'false_negatives', 1410 (predictions, labels, weights)): 1411 1412 predictions, labels, weights = _remove_squeezable_dimensions( 1413 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1414 labels=math_ops.cast(labels, dtype=dtypes.bool), 1415 weights=weights) 1416 is_false_negative = math_ops.logical_and( 1417 math_ops.equal(labels, True), math_ops.equal(predictions, False)) 1418 return _count_condition(is_false_negative, weights, metrics_collections, 1419 updates_collections) 1420 1421 1422 @tf_export('metrics.false_negatives_at_thresholds') 1423 def false_negatives_at_thresholds(labels, 1424 predictions, 1425 thresholds, 1426 weights=None, 1427 metrics_collections=None, 1428 updates_collections=None, 1429 name=None): 1430 """Computes false negatives at provided threshold values. 1431 1432 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1433 1434 Args: 1435 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1436 `bool`. 1437 predictions: A floating point `Tensor` of arbitrary shape and whose values 1438 are in the range `[0, 1]`. 1439 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1440 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1441 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1442 be either `1`, or the same as the corresponding `labels` dimension). 1443 metrics_collections: An optional list of collections that `false_negatives` 1444 should be added to. 1445 updates_collections: An optional list of collections that `update_op` should 1446 be added to. 1447 name: An optional variable_scope name. 1448 1449 Returns: 1450 false_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1451 update_op: An operation that updates the `false_negatives` variable and 1452 returns its current value. 
1453 1454 Raises: 1455 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1456 `weights` is not `None` and its shape doesn't match `predictions`, or if 1457 either `metrics_collections` or `updates_collections` are not a list or 1458 tuple. 1459 RuntimeError: If eager execution is enabled. 1460 """ 1461 if context.in_eager_mode(): 1462 raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not ' 1463 'supported when eager execution is enabled.') 1464 1465 with variable_scope.variable_scope(name, 'false_negatives', 1466 (predictions, labels, weights)): 1467 values, update_ops = _confusion_matrix_at_thresholds( 1468 labels, predictions, thresholds, weights=weights, includes=('fn',)) 1469 1470 if metrics_collections: 1471 ops.add_to_collections(metrics_collections, values['fn']) 1472 1473 if updates_collections: 1474 ops.add_to_collections(updates_collections, update_ops['fn']) 1475 1476 return values['fn'], update_ops['fn'] 1477 1478 1479 @tf_export('metrics.false_positives') 1480 def false_positives(labels, 1481 predictions, 1482 weights=None, 1483 metrics_collections=None, 1484 updates_collections=None, 1485 name=None): 1486 """Sum the weights of false positives. 1487 1488 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1489 1490 Args: 1491 labels: The ground truth values, a `Tensor` whose dimensions must match 1492 `predictions`. Will be cast to `bool`. 1493 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1494 be cast to `bool`. 1495 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1496 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1497 be either `1`, or the same as the corresponding `labels` dimension). 1498 metrics_collections: An optional list of collections that the metric 1499 value variable should be added to. 1500 updates_collections: An optional list of collections that the metric update 1501 ops should be added to. 1502 name: An optional variable_scope name. 1503 1504 Returns: 1505 value_tensor: A `Tensor` representing the current value of the metric. 1506 update_op: An operation that accumulates the error from a batch of data. 1507 1508 Raises: 1509 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1510 `weights` is not `None` and its shape doesn't match `predictions`, or if 1511 either `metrics_collections` or `updates_collections` are not a list or 1512 tuple. 1513 RuntimeError: If eager execution is enabled. 1514 """ 1515 if context.in_eager_mode(): 1516 raise RuntimeError('tf.metrics.false_positives is not supported when ' 1517 'eager execution is enabled.') 1518 1519 with variable_scope.variable_scope(name, 'false_positives', 1520 (predictions, labels, weights)): 1521 1522 predictions, labels, weights = _remove_squeezable_dimensions( 1523 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1524 labels=math_ops.cast(labels, dtype=dtypes.bool), 1525 weights=weights) 1526 is_false_positive = math_ops.logical_and( 1527 math_ops.equal(labels, False), math_ops.equal(predictions, True)) 1528 return _count_condition(is_false_positive, weights, metrics_collections, 1529 updates_collections) 1530 1531 1532 @tf_export('metrics.false_positives_at_thresholds') 1533 def false_positives_at_thresholds(labels, 1534 predictions, 1535 thresholds, 1536 weights=None, 1537 metrics_collections=None, 1538 updates_collections=None, 1539 name=None): 1540 """Computes false positives at provided threshold values. 
1541 1542 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1543 1544 Args: 1545 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1546 `bool`. 1547 predictions: A floating point `Tensor` of arbitrary shape and whose values 1548 are in the range `[0, 1]`. 1549 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1550 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1551 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1552 be either `1`, or the same as the corresponding `labels` dimension). 1553 metrics_collections: An optional list of collections that `false_positives` 1554 should be added to. 1555 updates_collections: An optional list of collections that `update_op` should 1556 be added to. 1557 name: An optional variable_scope name. 1558 1559 Returns: 1560 false_positives: A float `Tensor` of shape `[len(thresholds)]`. 1561 update_op: An operation that updates the `false_positives` variable and 1562 returns its current value. 1563 1564 Raises: 1565 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1566 `weights` is not `None` and its shape doesn't match `predictions`, or if 1567 either `metrics_collections` or `updates_collections` are not a list or 1568 tuple. 1569 RuntimeError: If eager execution is enabled. 1570 """ 1571 if context.in_eager_mode(): 1572 raise RuntimeError('tf.metrics.false_positives_at_thresholds is not ' 1573 'supported when eager execution is enabled.') 1574 1575 with variable_scope.variable_scope(name, 'false_positives', 1576 (predictions, labels, weights)): 1577 values, update_ops = _confusion_matrix_at_thresholds( 1578 labels, predictions, thresholds, weights=weights, includes=('fp',)) 1579 1580 if metrics_collections: 1581 ops.add_to_collections(metrics_collections, values['fp']) 1582 1583 if updates_collections: 1584 ops.add_to_collections(updates_collections, update_ops['fp']) 1585 1586 return values['fp'], update_ops['fp'] 1587 1588 1589 @tf_export('metrics.true_negatives') 1590 def true_negatives(labels, 1591 predictions, 1592 weights=None, 1593 metrics_collections=None, 1594 updates_collections=None, 1595 name=None): 1596 """Sum the weights of true_negatives. 1597 1598 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1599 1600 Args: 1601 labels: The ground truth values, a `Tensor` whose dimensions must match 1602 `predictions`. Will be cast to `bool`. 1603 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1604 be cast to `bool`. 1605 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1606 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1607 be either `1`, or the same as the corresponding `labels` dimension). 1608 metrics_collections: An optional list of collections that the metric 1609 value variable should be added to. 1610 updates_collections: An optional list of collections that the metric update 1611 ops should be added to. 1612 name: An optional variable_scope name. 1613 1614 Returns: 1615 value_tensor: A `Tensor` representing the current value of the metric. 1616 update_op: An operation that accumulates the error from a batch of data. 1617 1618 Raises: 1619 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1620 `weights` is not `None` and its shape doesn't match `predictions`, or if 1621 either `metrics_collections` or `updates_collections` are not a list or 1622 tuple. 
1623 RuntimeError: If eager execution is enabled. 1624 """ 1625 if context.in_eager_mode(): 1626 raise RuntimeError('tf.metrics.true_negatives is not ' 1627 'supported when eager execution is enabled.') 1628 1629 with variable_scope.variable_scope(name, 'true_negatives', 1630 (predictions, labels, weights)): 1631 1632 predictions, labels, weights = _remove_squeezable_dimensions( 1633 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1634 labels=math_ops.cast(labels, dtype=dtypes.bool), 1635 weights=weights) 1636 is_true_negative = math_ops.logical_and( 1637 math_ops.equal(labels, False), math_ops.equal(predictions, False)) 1638 return _count_condition(is_true_negative, weights, metrics_collections, 1639 updates_collections) 1640 1641 1642 @tf_export('metrics.true_negatives_at_thresholds') 1643 def true_negatives_at_thresholds(labels, 1644 predictions, 1645 thresholds, 1646 weights=None, 1647 metrics_collections=None, 1648 updates_collections=None, 1649 name=None): 1650 """Computes true negatives at provided threshold values. 1651 1652 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1653 1654 Args: 1655 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1656 `bool`. 1657 predictions: A floating point `Tensor` of arbitrary shape and whose values 1658 are in the range `[0, 1]`. 1659 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1660 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1661 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1662 be either `1`, or the same as the corresponding `labels` dimension). 1663 metrics_collections: An optional list of collections that `true_negatives` 1664 should be added to. 1665 updates_collections: An optional list of collections that `update_op` should 1666 be added to. 1667 name: An optional variable_scope name. 1668 1669 Returns: 1670 true_negatives: A float `Tensor` of shape `[len(thresholds)]`. 1671 update_op: An operation that updates the `true_negatives` variable and 1672 returns its current value. 1673 1674 Raises: 1675 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1676 `weights` is not `None` and its shape doesn't match `predictions`, or if 1677 either `metrics_collections` or `updates_collections` are not a list or 1678 tuple. 1679 RuntimeError: If eager execution is enabled. 1680 """ 1681 if context.in_eager_mode(): 1682 raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not ' 1683 'supported when eager execution is enabled.') 1684 1685 with variable_scope.variable_scope(name, 'true_negatives', 1686 (predictions, labels, weights)): 1687 values, update_ops = _confusion_matrix_at_thresholds( 1688 labels, predictions, thresholds, weights=weights, includes=('tn',)) 1689 1690 if metrics_collections: 1691 ops.add_to_collections(metrics_collections, values['tn']) 1692 1693 if updates_collections: 1694 ops.add_to_collections(updates_collections, update_ops['tn']) 1695 1696 return values['tn'], update_ops['tn'] 1697 1698 1699 @tf_export('metrics.true_positives') 1700 def true_positives(labels, 1701 predictions, 1702 weights=None, 1703 metrics_collections=None, 1704 updates_collections=None, 1705 name=None): 1706 """Sum the weights of true_positives. 1707 1708 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1709 1710 Args: 1711 labels: The ground truth values, a `Tensor` whose dimensions must match 1712 `predictions`. Will be cast to `bool`. 
1713 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1714 be cast to `bool`. 1715 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1716 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1717 be either `1`, or the same as the corresponding `labels` dimension). 1718 metrics_collections: An optional list of collections that the metric 1719 value variable should be added to. 1720 updates_collections: An optional list of collections that the metric update 1721 ops should be added to. 1722 name: An optional variable_scope name. 1723 1724 Returns: 1725 value_tensor: A `Tensor` representing the current value of the metric. 1726 update_op: An operation that accumulates the error from a batch of data. 1727 1728 Raises: 1729 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1730 `weights` is not `None` and its shape doesn't match `predictions`, or if 1731 either `metrics_collections` or `updates_collections` are not a list or 1732 tuple. 1733 RuntimeError: If eager execution is enabled. 1734 """ 1735 if context.in_eager_mode(): 1736 raise RuntimeError('tf.metrics.true_positives is not ' 1737 'supported when eager execution is enabled.') 1738 1739 with variable_scope.variable_scope(name, 'true_positives', 1740 (predictions, labels, weights)): 1741 1742 predictions, labels, weights = _remove_squeezable_dimensions( 1743 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1744 labels=math_ops.cast(labels, dtype=dtypes.bool), 1745 weights=weights) 1746 is_true_positive = math_ops.logical_and( 1747 math_ops.equal(labels, True), math_ops.equal(predictions, True)) 1748 return _count_condition(is_true_positive, weights, metrics_collections, 1749 updates_collections) 1750 1751 1752 @tf_export('metrics.true_positives_at_thresholds') 1753 def true_positives_at_thresholds(labels, 1754 predictions, 1755 thresholds, 1756 weights=None, 1757 metrics_collections=None, 1758 updates_collections=None, 1759 name=None): 1760 """Computes true positives at provided threshold values. 1761 1762 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1763 1764 Args: 1765 labels: A `Tensor` whose shape matches `predictions`. Will be cast to 1766 `bool`. 1767 predictions: A floating point `Tensor` of arbitrary shape and whose values 1768 are in the range `[0, 1]`. 1769 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1770 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1771 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1772 be either `1`, or the same as the corresponding `labels` dimension). 1773 metrics_collections: An optional list of collections that `true_positives` 1774 should be added to. 1775 updates_collections: An optional list of collections that `update_op` should 1776 be added to. 1777 name: An optional variable_scope name. 1778 1779 Returns: 1780 true_positives: A float `Tensor` of shape `[len(thresholds)]`. 1781 update_op: An operation that updates the `true_positives` variable and 1782 returns its current value. 1783 1784 Raises: 1785 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1786 `weights` is not `None` and its shape doesn't match `predictions`, or if 1787 either `metrics_collections` or `updates_collections` are not a list or 1788 tuple. 1789 RuntimeError: If eager execution is enabled. 
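
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and graph mode with a `tf.Session`):

    labels = tf.constant([True, False, True, True])
    predictions = tf.constant([0.9, 0.6, 0.4, 0.8])
    tp, update_op = tf.metrics.true_positives_at_thresholds(
        labels, predictions, thresholds=[0.5, 0.7])
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(tp))  # [2., 2.]: two true positives at each threshold.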
1790 """ 1791 if context.in_eager_mode(): 1792 raise RuntimeError('tf.metrics.true_positives_at_thresholds is not ' 1793 'supported when eager execution is enabled.') 1794 1795 with variable_scope.variable_scope(name, 'true_positives', 1796 (predictions, labels, weights)): 1797 values, update_ops = _confusion_matrix_at_thresholds( 1798 labels, predictions, thresholds, weights=weights, includes=('tp',)) 1799 1800 if metrics_collections: 1801 ops.add_to_collections(metrics_collections, values['tp']) 1802 1803 if updates_collections: 1804 ops.add_to_collections(updates_collections, update_ops['tp']) 1805 1806 return values['tp'], update_ops['tp'] 1807 1808 1809 @tf_export('metrics.precision') 1810 def precision(labels, 1811 predictions, 1812 weights=None, 1813 metrics_collections=None, 1814 updates_collections=None, 1815 name=None): 1816 """Computes the precision of the predictions with respect to the labels. 1817 1818 The `precision` function creates two local variables, 1819 `true_positives` and `false_positives`, that are used to compute the 1820 precision. This value is ultimately returned as `precision`, an idempotent 1821 operation that simply divides `true_positives` by the sum of `true_positives` 1822 and `false_positives`. 1823 1824 For estimation of the metric over a stream of data, the function creates an 1825 `update_op` operation that updates these variables and returns the 1826 `precision`. `update_op` weights each prediction by the corresponding value in 1827 `weights`. 1828 1829 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1830 1831 Args: 1832 labels: The ground truth values, a `Tensor` whose dimensions must match 1833 `predictions`. Will be cast to `bool`. 1834 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 1835 be cast to `bool`. 1836 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1837 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1838 be either `1`, or the same as the corresponding `labels` dimension). 1839 metrics_collections: An optional list of collections that `precision` should 1840 be added to. 1841 updates_collections: An optional list of collections that `update_op` should 1842 be added to. 1843 name: An optional variable_scope name. 1844 1845 Returns: 1846 precision: Scalar float `Tensor` with the value of `true_positives` 1847 divided by the sum of `true_positives` and `false_positives`. 1848 update_op: `Operation` that increments `true_positives` and 1849 `false_positives` variables appropriately and whose value matches 1850 `precision`. 1851 1852 Raises: 1853 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1854 `weights` is not `None` and its shape doesn't match `predictions`, or if 1855 either `metrics_collections` or `updates_collections` are not a list or 1856 tuple. 1857 RuntimeError: If eager execution is enabled. 
1858 """ 1859 if context.in_eager_mode(): 1860 raise RuntimeError('tf.metrics.precision is not ' 1861 'supported when eager execution is enabled.') 1862 1863 with variable_scope.variable_scope(name, 'precision', 1864 (predictions, labels, weights)): 1865 1866 predictions, labels, weights = _remove_squeezable_dimensions( 1867 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 1868 labels=math_ops.cast(labels, dtype=dtypes.bool), 1869 weights=weights) 1870 1871 true_p, true_positives_update_op = true_positives( 1872 labels, 1873 predictions, 1874 weights, 1875 metrics_collections=None, 1876 updates_collections=None, 1877 name=None) 1878 false_p, false_positives_update_op = false_positives( 1879 labels, 1880 predictions, 1881 weights, 1882 metrics_collections=None, 1883 updates_collections=None, 1884 name=None) 1885 1886 def compute_precision(tp, fp, name): 1887 return array_ops.where( 1888 math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name) 1889 1890 p = compute_precision(true_p, false_p, 'value') 1891 update_op = compute_precision(true_positives_update_op, 1892 false_positives_update_op, 'update_op') 1893 1894 if metrics_collections: 1895 ops.add_to_collections(metrics_collections, p) 1896 1897 if updates_collections: 1898 ops.add_to_collections(updates_collections, update_op) 1899 1900 return p, update_op 1901 1902 1903 @tf_export('metrics.precision_at_thresholds') 1904 def precision_at_thresholds(labels, 1905 predictions, 1906 thresholds, 1907 weights=None, 1908 metrics_collections=None, 1909 updates_collections=None, 1910 name=None): 1911 """Computes precision values for different `thresholds` on `predictions`. 1912 1913 The `precision_at_thresholds` function creates four local variables, 1914 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 1915 for various values of thresholds. `precision[i]` is defined as the total 1916 weight of values in `predictions` above `thresholds[i]` whose corresponding 1917 entry in `labels` is `True`, divided by the total weight of values in 1918 `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] + 1919 false_positives[i])`). 1920 1921 For estimation of the metric over a stream of data, the function creates an 1922 `update_op` operation that updates these variables and returns the 1923 `precision`. 1924 1925 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 1926 1927 Args: 1928 labels: The ground truth values, a `Tensor` whose dimensions must match 1929 `predictions`. Will be cast to `bool`. 1930 predictions: A floating point `Tensor` of arbitrary shape and whose values 1931 are in the range `[0, 1]`. 1932 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 1933 weights: Optional `Tensor` whose rank is either 0, or the same rank as 1934 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 1935 be either `1`, or the same as the corresponding `labels` dimension). 1936 metrics_collections: An optional list of collections that `auc` should be 1937 added to. 1938 updates_collections: An optional list of collections that `update_op` should 1939 be added to. 1940 name: An optional variable_scope name. 1941 1942 Returns: 1943 precision: A float `Tensor` of shape `[len(thresholds)]`. 1944 update_op: An operation that increments the `true_positives`, 1945 `true_negatives`, `false_positives` and `false_negatives` variables that 1946 are used in the computation of `precision`. 
1947 1948 Raises: 1949 ValueError: If `predictions` and `labels` have mismatched shapes, or if 1950 `weights` is not `None` and its shape doesn't match `predictions`, or if 1951 either `metrics_collections` or `updates_collections` are not a list or 1952 tuple. 1953 RuntimeError: If eager execution is enabled. 1954 """ 1955 if context.in_eager_mode(): 1956 raise RuntimeError('tf.metrics.precision_at_thresholds is not ' 1957 'supported when eager execution is enabled.') 1958 1959 with variable_scope.variable_scope(name, 'precision_at_thresholds', 1960 (predictions, labels, weights)): 1961 values, update_ops = _confusion_matrix_at_thresholds( 1962 labels, predictions, thresholds, weights, includes=('tp', 'fp')) 1963 1964 # Avoid division by zero. 1965 epsilon = 1e-7 1966 1967 def compute_precision(tp, fp, name): 1968 return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name) 1969 1970 prec = compute_precision(values['tp'], values['fp'], 'value') 1971 update_op = compute_precision(update_ops['tp'], update_ops['fp'], 1972 'update_op') 1973 1974 if metrics_collections: 1975 ops.add_to_collections(metrics_collections, prec) 1976 1977 if updates_collections: 1978 ops.add_to_collections(updates_collections, update_op) 1979 1980 return prec, update_op 1981 1982 1983 @tf_export('metrics.recall') 1984 def recall(labels, 1985 predictions, 1986 weights=None, 1987 metrics_collections=None, 1988 updates_collections=None, 1989 name=None): 1990 """Computes the recall of the predictions with respect to the labels. 1991 1992 The `recall` function creates two local variables, `true_positives` 1993 and `false_negatives`, that are used to compute the recall. This value is 1994 ultimately returned as `recall`, an idempotent operation that simply divides 1995 `true_positives` by the sum of `true_positives` and `false_negatives`. 1996 1997 For estimation of the metric over a stream of data, the function creates an 1998 `update_op` that updates these variables and returns the `recall`. `update_op` 1999 weights each prediction by the corresponding value in `weights`. 2000 2001 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2002 2003 Args: 2004 labels: The ground truth values, a `Tensor` whose dimensions must match 2005 `predictions`. Will be cast to `bool`. 2006 predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will 2007 be cast to `bool`. 2008 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2009 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2010 be either `1`, or the same as the corresponding `labels` dimension). 2011 metrics_collections: An optional list of collections that `recall` should 2012 be added to. 2013 updates_collections: An optional list of collections that `update_op` should 2014 be added to. 2015 name: An optional variable_scope name. 2016 2017 Returns: 2018 recall: Scalar float `Tensor` with the value of `true_positives` divided 2019 by the sum of `true_positives` and `false_negatives`. 2020 update_op: `Operation` that increments `true_positives` and 2021 `false_negatives` variables appropriately and whose value matches 2022 `recall`. 2023 2024 Raises: 2025 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2026 `weights` is not `None` and its shape doesn't match `predictions`, or if 2027 either `metrics_collections` or `updates_collections` are not a list or 2028 tuple. 2029 RuntimeError: If eager execution is enabled. 
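
  For example, a minimal usage sketch (illustrative values; assumes
  `import tensorflow as tf` and graph mode with a `tf.Session`):

    labels = tf.constant([True, True, False, True])
    predictions = tf.constant([True, False, False, True])
    rec, update_op = tf.metrics.recall(labels, predictions)
    with tf.Session() as sess:
      sess.run(tf.local_variables_initializer())
      sess.run(update_op)
      print(sess.run(rec))  # 2 TP / (2 TP + 1 FN) = 0.666...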
2030 """ 2031 if context.in_eager_mode(): 2032 raise RuntimeError('tf.metrics.recall is not supported is not ' 2033 'supported when eager execution is enabled.') 2034 2035 with variable_scope.variable_scope(name, 'recall', 2036 (predictions, labels, weights)): 2037 predictions, labels, weights = _remove_squeezable_dimensions( 2038 predictions=math_ops.cast(predictions, dtype=dtypes.bool), 2039 labels=math_ops.cast(labels, dtype=dtypes.bool), 2040 weights=weights) 2041 2042 true_p, true_positives_update_op = true_positives( 2043 labels, 2044 predictions, 2045 weights, 2046 metrics_collections=None, 2047 updates_collections=None, 2048 name=None) 2049 false_n, false_negatives_update_op = false_negatives( 2050 labels, 2051 predictions, 2052 weights, 2053 metrics_collections=None, 2054 updates_collections=None, 2055 name=None) 2056 2057 def compute_recall(true_p, false_n, name): 2058 return array_ops.where( 2059 math_ops.greater(true_p + false_n, 0), 2060 math_ops.div(true_p, true_p + false_n), 0, name) 2061 2062 rec = compute_recall(true_p, false_n, 'value') 2063 update_op = compute_recall(true_positives_update_op, 2064 false_negatives_update_op, 'update_op') 2065 2066 if metrics_collections: 2067 ops.add_to_collections(metrics_collections, rec) 2068 2069 if updates_collections: 2070 ops.add_to_collections(updates_collections, update_op) 2071 2072 return rec, update_op 2073 2074 2075 def _at_k_name(name, k=None, class_id=None): 2076 if k is not None: 2077 name = '%s_at_%d' % (name, k) 2078 else: 2079 name = '%s_at_k' % (name) 2080 if class_id is not None: 2081 name = '%s_class%d' % (name, class_id) 2082 return name 2083 2084 2085 def _select_class_id(ids, selected_id): 2086 """Filter all but `selected_id` out of `ids`. 2087 2088 Args: 2089 ids: `int64` `Tensor` or `SparseTensor` of IDs. 2090 selected_id: Int id to select. 2091 2092 Returns: 2093 `SparseTensor` of same dimensions as `ids`. This contains only the entries 2094 equal to `selected_id`. 2095 """ 2096 ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids) 2097 if isinstance(ids, sparse_tensor.SparseTensor): 2098 return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values, 2099 selected_id)) 2100 2101 # TODO(ptucker): Make this more efficient, maybe add a sparse version of 2102 # tf.equal and tf.reduce_any? 2103 2104 # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1. 2105 ids_shape = array_ops.shape(ids, out_type=dtypes.int64) 2106 ids_last_dim = array_ops.size(ids_shape) - 1 2107 filled_selected_id_shape = math_ops.reduced_shape(ids_shape, 2108 array_ops.reshape( 2109 ids_last_dim, [1])) 2110 2111 # Intersect `ids` with the selected ID. 2112 filled_selected_id = array_ops.fill(filled_selected_id_shape, 2113 math_ops.to_int64(selected_id)) 2114 result = sets.set_intersection(filled_selected_id, ids) 2115 return sparse_tensor.SparseTensor( 2116 indices=result.indices, values=result.values, dense_shape=ids_shape) 2117 2118 2119 def _maybe_select_class_id(labels, predictions_idx, selected_id=None): 2120 """If class ID is specified, filter all other classes. 2121 2122 Args: 2123 labels: `int64` `Tensor` or `SparseTensor` with shape 2124 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2125 target classes for the associated prediction. Commonly, N=1 and `labels` 2126 has shape [batch_size, num_labels]. [D1, ... DN] must match 2127 `predictions_idx`. 2128 predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k] 2129 where N >= 1. 
Commonly, N=1 and `predictions_idx` has shape 2130 [batch size, k]. 2131 selected_id: Int id to select. 2132 2133 Returns: 2134 Tuple of `labels` and `predictions_idx`, possibly with classes removed. 2135 """ 2136 if selected_id is None: 2137 return labels, predictions_idx 2138 return (_select_class_id(labels, selected_id), 2139 _select_class_id(predictions_idx, selected_id)) 2140 2141 2142 def _sparse_true_positive_at_k(labels, 2143 predictions_idx, 2144 class_id=None, 2145 weights=None, 2146 name=None): 2147 """Calculates true positives for recall@k and precision@k. 2148 2149 If `class_id` is specified, calculate binary true positives for `class_id` 2150 only. 2151 If `class_id` is not specified, calculate metrics for `k` predicted vs 2152 `n` label classes, where `n` is the 2nd dimension of `labels_sparse`. 2153 2154 Args: 2155 labels: `int64` `Tensor` or `SparseTensor` with shape 2156 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2157 target classes for the associated prediction. Commonly, N=1 and `labels` 2158 has shape [batch_size, num_labels]. [D1, ... DN] must match 2159 `predictions_idx`. 2160 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2161 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2162 match `labels`. 2163 class_id: Class for which we want binary metrics. 2164 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2165 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2166 dimensions must be either `1`, or the same as the corresponding `labels` 2167 dimension). 2168 name: Name of operation. 2169 2170 Returns: 2171 A [D1, ... DN] `Tensor` of true positive counts. 2172 """ 2173 with ops.name_scope(name, 'true_positives', 2174 (predictions_idx, labels, weights)): 2175 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx, 2176 class_id) 2177 tp = sets.set_size(sets.set_intersection(predictions_idx, labels)) 2178 tp = math_ops.to_double(tp) 2179 if weights is not None: 2180 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable( 2181 weights, tp),)): 2182 weights = math_ops.to_double(weights) 2183 tp = math_ops.multiply(tp, weights) 2184 return tp 2185 2186 2187 def _streaming_sparse_true_positive_at_k(labels, 2188 predictions_idx, 2189 k=None, 2190 class_id=None, 2191 weights=None, 2192 name=None): 2193 """Calculates weighted per step true positives for recall@k and precision@k. 2194 2195 If `class_id` is specified, calculate binary true positives for `class_id` 2196 only. 2197 If `class_id` is not specified, calculate metrics for `k` predicted vs 2198 `n` label classes, where `n` is the 2nd dimension of `labels`. 2199 2200 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2201 2202 Args: 2203 labels: `int64` `Tensor` or `SparseTensor` with shape 2204 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2205 target classes for the associated prediction. Commonly, N=1 and `labels` 2206 has shape [batch_size, num_labels]. [D1, ... DN] must match 2207 `predictions_idx`. 2208 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2209 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2210 match `labels`. 2211 k: Integer, k for @k metric. This is only used for default op name. 2212 class_id: Class for which we want binary metrics. 2213 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2214 `labels`. 
If the latter, it must be broadcastable to `labels` (i.e., all
2215       dimensions must be either `1`, or the same as the corresponding `labels`
2216       dimension).
2217     name: Name of new variable, and namespace for other dependent ops.
2218 
2219   Returns:
2220     A tuple of `Variable` and update `Operation`.
2221 
2222   Raises:
2223     ValueError: If `weights` is not `None` and has an incompatible shape.
2224   """
2225   with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
2226                       (predictions_idx, labels, weights)) as scope:
2227     tp = _sparse_true_positive_at_k(
2228         predictions_idx=predictions_idx,
2229         labels=labels,
2230         class_id=class_id,
2231         weights=weights)
2232     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
2233 
2234     var = metric_variable([], dtypes.float64, name=scope)
2235     return var, state_ops.assign_add(var, batch_total_tp, name='update')
2236 
2237 
2238 def _sparse_false_negative_at_k(labels,
2239                                 predictions_idx,
2240                                 class_id=None,
2241                                 weights=None):
2242   """Calculates false negatives for recall@k.
2243 
2244   If `class_id` is specified, calculate binary false negatives for `class_id`
2245   only.
2246   If `class_id` is not specified, calculate metrics for `k` predicted vs
2247   `n` label classes, where `n` is the 2nd dimension of `labels`.
2248 
2249   Args:
2250     labels: `int64` `Tensor` or `SparseTensor` with shape
2251       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2252       target classes for the associated prediction. Commonly, N=1 and `labels`
2253       has shape [batch_size, num_labels]. [D1, ... DN] must match
2254       `predictions_idx`.
2255     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
2256       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
2257       match `labels`.
2258     class_id: Class for which we want binary metrics.
2259     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
2260       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
2261       dimensions must be either `1`, or the same as the corresponding `labels`
2262       dimension).
2263 
2264   Returns:
2265     A [D1, ... DN] `Tensor` of false negative counts.
2266   """
2267   with ops.name_scope(None, 'false_negatives',
2268                       (predictions_idx, labels, weights)):
2269     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
2270                                                      class_id)
2271     fn = sets.set_size(
2272         sets.set_difference(predictions_idx, labels, aminusb=False))
2273     fn = math_ops.to_double(fn)
2274     if weights is not None:
2275       with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
2276           weights, fn),)):
2277         weights = math_ops.to_double(weights)
2278         fn = math_ops.multiply(fn, weights)
2279     return fn
2280 
2281 
2282 def _streaming_sparse_false_negative_at_k(labels,
2283                                           predictions_idx,
2284                                           k,
2285                                           class_id=None,
2286                                           weights=None,
2287                                           name=None):
2288   """Calculates weighted per step false negatives for recall@k.
2289 
2290   If `class_id` is specified, calculate binary false negatives for `class_id`
2291   only.
2292   If `class_id` is not specified, calculate metrics for `k` predicted vs
2293   `n` label classes, where `n` is the 2nd dimension of `labels`.
2294 
2295   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2296 
2297   Args:
2298     labels: `int64` `Tensor` or `SparseTensor` with shape
2299       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
2300       target classes for the associated prediction. Commonly, N=1 and `labels`
2301       has shape [batch_size, num_labels]. [D1, ...
DN] must match 2302 `predictions_idx`. 2303 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`, 2304 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must 2305 match `labels`. 2306 k: Integer, k for @k metric. This is only used for default op name. 2307 class_id: Class for which we want binary metrics. 2308 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2309 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2310 dimensions must be either `1`, or the same as the corresponding `labels` 2311 dimension). 2312 name: Name of new variable, and namespace for other dependent ops. 2313 2314 Returns: 2315 A tuple of `Variable` and update `Operation`. 2316 2317 Raises: 2318 ValueError: If `weights` is not `None` and has an incompatible shape. 2319 """ 2320 with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id), 2321 (predictions_idx, labels, weights)) as scope: 2322 fn = _sparse_false_negative_at_k( 2323 predictions_idx=predictions_idx, 2324 labels=labels, 2325 class_id=class_id, 2326 weights=weights) 2327 batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn)) 2328 2329 var = metric_variable([], dtypes.float64, name=scope) 2330 return var, state_ops.assign_add(var, batch_total_fn, name='update') 2331 2332 2333 @tf_export('metrics.recall_at_k') 2334 def recall_at_k(labels, 2335 predictions, 2336 k, 2337 class_id=None, 2338 weights=None, 2339 metrics_collections=None, 2340 updates_collections=None, 2341 name=None): 2342 """Computes recall@k of the predictions with respect to sparse labels. 2343 2344 If `class_id` is specified, we calculate recall by considering only the 2345 entries in the batch for which `class_id` is in the label, and computing 2346 the fraction of them for which `class_id` is in the top-k `predictions`. 2347 If `class_id` is not specified, we'll calculate recall as how often on 2348 average a class among the labels of a batch entry is in the top-k 2349 `predictions`. 2350 2351 `sparse_recall_at_k` creates two local variables, 2352 `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute 2353 the recall_at_k frequency. This frequency is ultimately returned as 2354 `recall_at_<k>`: an idempotent operation that simply divides 2355 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 2356 `false_negative_at_<k>`). 2357 2358 For estimation of the metric over a stream of data, the function creates an 2359 `update_op` operation that updates these variables and returns the 2360 `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 2361 indicating the top `k` `predictions`. Set operations applied to `top_k` and 2362 `labels` calculate the true positives and false negatives weighted by 2363 `weights`. Then `update_op` increments `true_positive_at_<k>` and 2364 `false_negative_at_<k>` using these values. 2365 2366 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2367 2368 Args: 2369 labels: `int64` `Tensor` or `SparseTensor` with shape 2370 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2371 num_labels=1. N >= 1 and num_labels is the number of target classes for 2372 the associated prediction. Commonly, N=1 and `labels` has shape 2373 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2374 should be in range [0, num_classes), where num_classes is the last 2375 dimension of `predictions`. Values outside this range always count 2376 towards `false_negative_at_<k>`. 
2377 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where 2378 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes]. 2379 The final dimension contains the logit values for each class. [D1, ... DN] 2380 must match `labels`. 2381 k: Integer, k for @k metric. 2382 class_id: Integer class ID for which we want binary metrics. This should be 2383 in range [0, num_classes), where num_classes is the last dimension of 2384 `predictions`. If class_id is outside this range, the method returns NAN. 2385 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2386 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2387 dimensions must be either `1`, or the same as the corresponding `labels` 2388 dimension). 2389 metrics_collections: An optional list of collections that values should 2390 be added to. 2391 updates_collections: An optional list of collections that updates should 2392 be added to. 2393 name: Name of new update operation, and namespace for other dependent ops. 2394 2395 Returns: 2396 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2397 by the sum of `true_positives` and `false_negatives`. 2398 update_op: `Operation` that increments `true_positives` and 2399 `false_negatives` variables appropriately, and whose value matches 2400 `recall`. 2401 2402 Raises: 2403 ValueError: If `weights` is not `None` and its shape doesn't match 2404 `predictions`, or if either `metrics_collections` or `updates_collections` 2405 are not a list or tuple. 2406 RuntimeError: If eager execution is enabled. 2407 """ 2408 if context.in_eager_mode(): 2409 raise RuntimeError('tf.metrics.recall_at_k is not ' 2410 'supported when eager execution is enabled.') 2411 2412 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2413 (predictions, labels, weights)) as scope: 2414 _, top_k_idx = nn.top_k(predictions, k) 2415 return recall_at_top_k( 2416 labels=labels, 2417 predictions_idx=top_k_idx, 2418 k=k, 2419 class_id=class_id, 2420 weights=weights, 2421 metrics_collections=metrics_collections, 2422 updates_collections=updates_collections, 2423 name=scope) 2424 2425 2426 @tf_export('metrics.recall_at_top_k') 2427 def recall_at_top_k(labels, 2428 predictions_idx, 2429 k=None, 2430 class_id=None, 2431 weights=None, 2432 metrics_collections=None, 2433 updates_collections=None, 2434 name=None): 2435 """Computes recall@k of top-k predictions with respect to sparse labels. 2436 2437 Differs from `recall_at_k` in that predictions must be in the form of top `k` 2438 class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k` 2439 for more details. 2440 2441 Args: 2442 labels: `int64` `Tensor` or `SparseTensor` with shape 2443 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2444 num_labels=1. N >= 1 and num_labels is the number of target classes for 2445 the associated prediction. Commonly, N=1 and `labels` has shape 2446 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values 2447 should be in range [0, num_classes), where num_classes is the last 2448 dimension of `predictions`. Values outside this range always count 2449 towards `false_negative_at_<k>`. 2450 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2451 Commonly, N=1 and predictions has shape [batch size, k]. The final 2452 dimension contains the top `k` predicted class indices. [D1, ... DN] must 2453 match `labels`. 2454 k: Integer, k for @k metric. 
Only used for the default op name. 2455 class_id: Integer class ID for which we want binary metrics. This should be 2456 in range [0, num_classes), where num_classes is the last dimension of 2457 `predictions`. If class_id is outside this range, the method returns NAN. 2458 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2459 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2460 dimensions must be either `1`, or the same as the corresponding `labels` 2461 dimension). 2462 metrics_collections: An optional list of collections that values should 2463 be added to. 2464 updates_collections: An optional list of collections that updates should 2465 be added to. 2466 name: Name of new update operation, and namespace for other dependent ops. 2467 2468 Returns: 2469 recall: Scalar `float64` `Tensor` with the value of `true_positives` divided 2470 by the sum of `true_positives` and `false_negatives`. 2471 update_op: `Operation` that increments `true_positives` and 2472 `false_negatives` variables appropriately, and whose value matches 2473 `recall`. 2474 2475 Raises: 2476 ValueError: If `weights` is not `None` and its shape doesn't match 2477 `predictions`, or if either `metrics_collections` or `updates_collections` 2478 are not a list or tuple. 2479 """ 2480 with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id), 2481 (predictions_idx, labels, weights)) as scope: 2482 labels = _maybe_expand_labels(labels, predictions_idx) 2483 top_k_idx = math_ops.to_int64(predictions_idx) 2484 tp, tp_update = _streaming_sparse_true_positive_at_k( 2485 predictions_idx=top_k_idx, 2486 labels=labels, 2487 k=k, 2488 class_id=class_id, 2489 weights=weights) 2490 fn, fn_update = _streaming_sparse_false_negative_at_k( 2491 predictions_idx=top_k_idx, 2492 labels=labels, 2493 k=k, 2494 class_id=class_id, 2495 weights=weights) 2496 2497 metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope) 2498 update = math_ops.div( 2499 tp_update, math_ops.add(tp_update, fn_update), name='update') 2500 if metrics_collections: 2501 ops.add_to_collections(metrics_collections, metric) 2502 if updates_collections: 2503 ops.add_to_collections(updates_collections, update) 2504 return metric, update 2505 2506 2507 @tf_export('metrics.recall_at_thresholds') 2508 def recall_at_thresholds(labels, 2509 predictions, 2510 thresholds, 2511 weights=None, 2512 metrics_collections=None, 2513 updates_collections=None, 2514 name=None): 2515 """Computes various recall values for different `thresholds` on `predictions`. 2516 2517 The `recall_at_thresholds` function creates four local variables, 2518 `true_positives`, `true_negatives`, `false_positives` and `false_negatives` 2519 for various values of thresholds. `recall[i]` is defined as the total weight 2520 of values in `predictions` above `thresholds[i]` whose corresponding entry in 2521 `labels` is `True`, divided by the total weight of `True` values in `labels` 2522 (`true_positives[i] / (true_positives[i] + false_negatives[i])`). 2523 2524 For estimation of the metric over a stream of data, the function creates an 2525 `update_op` operation that updates these variables and returns the `recall`. 2526 2527 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2528 2529 Args: 2530 labels: The ground truth values, a `Tensor` whose dimensions must match 2531 `predictions`. Will be cast to `bool`. 2532 predictions: A floating point `Tensor` of arbitrary shape and whose values 2533 are in the range `[0, 1]`. 
2534 thresholds: A python list or tuple of float thresholds in `[0, 1]`. 2535 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2536 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2537 be either `1`, or the same as the corresponding `labels` dimension). 2538 metrics_collections: An optional list of collections that `recall` should be 2539 added to. 2540 updates_collections: An optional list of collections that `update_op` should 2541 be added to. 2542 name: An optional variable_scope name. 2543 2544 Returns: 2545 recall: A float `Tensor` of shape `[len(thresholds)]`. 2546 update_op: An operation that increments the `true_positives`, 2547 `true_negatives`, `false_positives` and `false_negatives` variables that 2548 are used in the computation of `recall`. 2549 2550 Raises: 2551 ValueError: If `predictions` and `labels` have mismatched shapes, or if 2552 `weights` is not `None` and its shape doesn't match `predictions`, or if 2553 either `metrics_collections` or `updates_collections` are not a list or 2554 tuple. 2555 RuntimeError: If eager execution is enabled. 2556 """ 2557 if context.in_eager_mode(): 2558 raise RuntimeError('tf.metrics.recall_at_thresholds is not ' 2559 'supported when eager execution is enabled.') 2560 2561 with variable_scope.variable_scope(name, 'recall_at_thresholds', 2562 (predictions, labels, weights)): 2563 values, update_ops = _confusion_matrix_at_thresholds( 2564 labels, predictions, thresholds, weights, includes=('tp', 'fn')) 2565 2566 # Avoid division by zero. 2567 epsilon = 1e-7 2568 2569 def compute_recall(tp, fn, name): 2570 return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name) 2571 2572 rec = compute_recall(values['tp'], values['fn'], 'value') 2573 update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op') 2574 2575 if metrics_collections: 2576 ops.add_to_collections(metrics_collections, rec) 2577 2578 if updates_collections: 2579 ops.add_to_collections(updates_collections, update_op) 2580 2581 return rec, update_op 2582 2583 2584 @tf_export('metrics.root_mean_squared_error') 2585 def root_mean_squared_error(labels, 2586 predictions, 2587 weights=None, 2588 metrics_collections=None, 2589 updates_collections=None, 2590 name=None): 2591 """Computes the root mean squared error between the labels and predictions. 2592 2593 The `root_mean_squared_error` function creates two local variables, 2594 `total` and `count` that are used to compute the root mean squared error. 2595 This average is weighted by `weights`, and it is ultimately returned as 2596 `root_mean_squared_error`: an idempotent operation that takes the square root 2597 of the division of `total` by `count`. 2598 2599 For estimation of the metric over a stream of data, the function creates an 2600 `update_op` operation that updates these variables and returns the 2601 `root_mean_squared_error`. Internally, a `squared_error` operation computes 2602 the element-wise square of the difference between `predictions` and `labels`. 2603 Then `update_op` increments `total` with the reduced sum of the product of 2604 `weights` and `squared_error`, and it increments `count` with the reduced sum 2605 of `weights`. 2606 2607 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2608 2609 Args: 2610 labels: A `Tensor` of the same shape as `predictions`. 2611 predictions: A `Tensor` of arbitrary shape. 
2612     weights: Optional `Tensor` whose rank is either 0, or the same rank as
2613       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
2614       be either `1`, or the same as the corresponding `labels` dimension).
2615     metrics_collections: An optional list of collections that
2616       `root_mean_squared_error` should be added to.
2617     updates_collections: An optional list of collections that `update_op` should
2618       be added to.
2619     name: An optional variable_scope name.
2620 
2621   Returns:
2622     root_mean_squared_error: A `Tensor` representing the current mean, the value
2623       of `total` divided by `count`.
2624     update_op: An operation that increments the `total` and `count` variables
2625       appropriately and whose value matches `root_mean_squared_error`.
2626 
2627   Raises:
2628     ValueError: If `predictions` and `labels` have mismatched shapes, or if
2629       `weights` is not `None` and its shape doesn't match `predictions`, or if
2630       either `metrics_collections` or `updates_collections` are not a list or
2631       tuple.
2632     RuntimeError: If eager execution is enabled.
2633   """
2634   if context.in_eager_mode():
2635     raise RuntimeError('tf.metrics.root_mean_squared_error is not '
2636                        'supported when eager execution is enabled.')
2637 
2638   predictions, labels, weights = _remove_squeezable_dimensions(
2639       predictions=predictions, labels=labels, weights=weights)
2640   mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
2641                                           None, name or
2642                                           'root_mean_squared_error')
2643 
2644   rmse = math_ops.sqrt(mse)
2645   update_rmse_op = math_ops.sqrt(update_mse_op)
2646 
2647   if metrics_collections:
2648     ops.add_to_collections(metrics_collections, rmse)
2649 
2650   if updates_collections:
2651     ops.add_to_collections(updates_collections, update_rmse_op)
2652 
2653   return rmse, update_rmse_op
2654 
2655 
2656 @tf_export('metrics.sensitivity_at_specificity')
2657 def sensitivity_at_specificity(labels,
2658                                predictions,
2659                                specificity,
2660                                weights=None,
2661                                num_thresholds=200,
2662                                metrics_collections=None,
2663                                updates_collections=None,
2664                                name=None):
2665   """Computes the sensitivity at a given specificity.
2666 
2667   The `sensitivity_at_specificity` function creates four local
2668   variables, `true_positives`, `true_negatives`, `false_positives` and
2669   `false_negatives` that are used to compute the sensitivity at the given
2670   specificity value. The threshold for the given specificity value is computed
2671   and used to evaluate the corresponding sensitivity.
2672 
2673   For estimation of the metric over a stream of data, the function creates an
2674   `update_op` operation that updates these variables and returns the
2675   `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
2676   `false_positives` and `false_negatives` counts with the weight of each case
2677   found in the `predictions` and `labels`.
2678 
2679   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
2680 
2681   For additional information about specificity and sensitivity, see the
2682   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
2683 
2684   Args:
2685     labels: The ground truth values, a `Tensor` whose dimensions must match
2686       `predictions`. Will be cast to `bool`.
2687     predictions: A floating point `Tensor` of arbitrary shape and whose values
2688       are in the range `[0, 1]`.
2689     specificity: A scalar value in range `[0, 1]`.
2690 weights: Optional `Tensor` whose rank is either 0, or the same rank as 2691 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must 2692 be either `1`, or the same as the corresponding `labels` dimension). 2693 num_thresholds: The number of thresholds to use for matching the given 2694 specificity. 2695 metrics_collections: An optional list of collections that `sensitivity` 2696 should be added to. 2697 updates_collections: An optional list of collections that `update_op` should 2698 be added to. 2699 name: An optional variable_scope name. 2700 2701 Returns: 2702 sensitivity: A scalar `Tensor` representing the sensitivity at the given 2703 `specificity` value. 2704 update_op: An operation that increments the `true_positives`, 2705 `true_negatives`, `false_positives` and `false_negatives` variables 2706 appropriately and whose value matches `sensitivity`. 2707 2708 Raises: 2709 ValueError: If `predictions` and `labels` have mismatched shapes, if 2710 `weights` is not `None` and its shape doesn't match `predictions`, or if 2711 `specificity` is not between 0 and 1, or if either `metrics_collections` 2712 or `updates_collections` are not a list or tuple. 2713 RuntimeError: If eager execution is enabled. 2714 """ 2715 if context.in_eager_mode(): 2716 raise RuntimeError('tf.metrics.sensitivity_at_specificity is not ' 2717 'supported when eager execution is enabled.') 2718 2719 if specificity < 0 or specificity > 1: 2720 raise ValueError('`specificity` must be in the range [0, 1].') 2721 2722 with variable_scope.variable_scope(name, 'sensitivity_at_specificity', 2723 (predictions, labels, weights)): 2724 kepsilon = 1e-7 # to account for floating point imprecisions 2725 thresholds = [ 2726 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2) 2727 ] 2728 thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon] 2729 2730 values, update_ops = _confusion_matrix_at_thresholds( 2731 labels, predictions, thresholds, weights) 2732 2733 def compute_sensitivity_at_specificity(tp, tn, fp, fn, name): 2734 specificities = math_ops.div(tn, tn + fp + kepsilon) 2735 tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0) 2736 tf_index = math_ops.cast(tf_index, dtypes.int32) 2737 2738 # Now, we have the implicit threshold, so compute the sensitivity: 2739 return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon, 2740 name) 2741 2742 sensitivity = compute_sensitivity_at_specificity( 2743 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 2744 update_op = compute_sensitivity_at_specificity( 2745 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 2746 'update_op') 2747 2748 if metrics_collections: 2749 ops.add_to_collections(metrics_collections, sensitivity) 2750 2751 if updates_collections: 2752 ops.add_to_collections(updates_collections, update_op) 2753 2754 return sensitivity, update_op 2755 2756 2757 def _expand_and_tile(tensor, multiple, dim=0, name=None): 2758 """Slice `tensor` shape in 2, then tile along the sliced dimension. 2759 2760 A new dimension is inserted in shape of `tensor` before `dim`, then values are 2761 tiled `multiple` times along the new dimension. 2762 2763 Args: 2764 tensor: Input `Tensor` or `SparseTensor`. 2765 multiple: Integer, number of times to tile. 2766 dim: Integer, dimension along which to tile. 2767 name: Name of operation. 2768 2769 Returns: 2770 `Tensor` result of expanding and tiling `tensor`. 
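
  For example (illustrative): for a dense `tensor` of shape [2, 2],
  `_expand_and_tile(tensor, multiple=3, dim=-1)` inserts a new axis before the
  last dimension and tiles along it, producing a [2, 3, 2] result in which each
  row of `tensor` appears 3 times. This is how `labels` of shape
  [D1, ... DN, num_labels] are replicated to [D1, ... DN, k, num_labels] below.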
2771 2772 Raises: 2773 ValueError: if `multiple` is less than 1, or `dim` is not in 2774 `[-rank(tensor), rank(tensor)]`. 2775 """ 2776 if multiple < 1: 2777 raise ValueError('Invalid multiple %s, must be > 0.' % multiple) 2778 with ops.name_scope(name, 'expand_and_tile', 2779 (tensor, multiple, dim)) as scope: 2780 # Sparse. 2781 tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor) 2782 if isinstance(tensor, sparse_tensor.SparseTensor): 2783 if dim < 0: 2784 expand_dims = array_ops.reshape( 2785 array_ops.size(tensor.dense_shape) + dim, [1]) 2786 else: 2787 expand_dims = [dim] 2788 expanded_shape = array_ops.concat( 2789 (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1], 2790 array_ops.slice(tensor.dense_shape, expand_dims, [-1])), 2791 0, 2792 name='expanded_shape') 2793 expanded = sparse_ops.sparse_reshape( 2794 tensor, shape=expanded_shape, name='expand') 2795 if multiple == 1: 2796 return expanded 2797 return sparse_ops.sparse_concat( 2798 dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope) 2799 2800 # Dense. 2801 expanded = array_ops.expand_dims( 2802 tensor, dim if (dim >= 0) else (dim - 1), name='expand') 2803 if multiple == 1: 2804 return expanded 2805 ones = array_ops.ones_like(array_ops.shape(tensor)) 2806 tile_multiples = array_ops.concat( 2807 (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples') 2808 return array_ops.tile(expanded, tile_multiples, name=scope) 2809 2810 2811 def _num_relevant(labels, k): 2812 """Computes number of relevant values for each row in labels. 2813 2814 For labels with shape [D1, ... DN, num_labels], this is the minimum of 2815 `num_labels` and `k`. 2816 2817 Args: 2818 labels: `int64` `Tensor` or `SparseTensor` with shape 2819 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of 2820 target classes for the associated prediction. Commonly, N=1 and `labels` 2821 has shape [batch_size, num_labels]. 2822 k: Integer, k for @k metric. 2823 2824 Returns: 2825 Integer `Tensor` of shape [D1, ... DN], where each value is the number of 2826 relevant values for that row. 2827 2828 Raises: 2829 ValueError: if inputs have invalid dtypes or values. 2830 """ 2831 if k < 1: 2832 raise ValueError('Invalid k=%s.' % k) 2833 with ops.name_scope(None, 'num_relevant', (labels,)) as scope: 2834 # For SparseTensor, calculate separate count for each row. 2835 labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels) 2836 if isinstance(labels, sparse_tensor.SparseTensor): 2837 return math_ops.minimum(sets.set_size(labels), k, name=scope) 2838 2839 # For dense Tensor, calculate scalar count based on last dimension, and 2840 # tile across labels shape. 2841 labels_shape = array_ops.shape(labels) 2842 labels_size = labels_shape[-1] 2843 num_relevant_scalar = math_ops.minimum(labels_size, k) 2844 return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope) 2845 2846 2847 def _sparse_average_precision_at_top_k(labels, predictions_idx): 2848 """Computes average precision@k of predictions with respect to sparse labels. 2849 2850 From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula 2851 for each row is: 2852 2853 AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items 2854 2855 A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`, 2856 `labels`, and the result `Tensors`. In the common case, this is [batch_size]. 2857 Each row of the results contains the average precision for that row. 
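For example (an illustrative worked case of the formula above): if a row has labels {0, 2} and `predictions_idx` [2, 1, 0] (so k=3), then rel = [1, 0, 1], the precisions are P = [1/1, 1/2, 2/3], num_relevant_items = min(2, 3) = 2, and AveP = (1 * 1 + 0 * 1/2 + 1 * 2/3) / 2 = 5/6.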
2858 2859 Args: 2860 labels: `int64` `Tensor` or `SparseTensor` with shape 2861 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2862 num_labels=1. N >= 1 and num_labels is the number of target classes for 2863 the associated prediction. Commonly, N=1 and `labels` has shape 2864 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 2865 Values should be in range [0, num_classes). 2866 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2867 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 2868 dimension must be set and contains the top `k` predicted class indices. 2869 [D1, ... DN] must match `labels`. Values should be in range 2870 [0, num_classes). 2871 2872 Returns: 2873 `float64` `Tensor` of shape [D1, ... DN], where each value is the average 2874 precision for that row. 2875 2876 Raises: 2877 ValueError: if the last dimension of predictions_idx is not set. 2878 """ 2879 with ops.name_scope(None, 'average_precision', 2880 (predictions_idx, labels)) as scope: 2881 predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx') 2882 if predictions_idx.get_shape().ndims == 0: 2883 raise ValueError('The rank of predictions_idx must be at least 1.') 2884 k = predictions_idx.get_shape().as_list()[-1] 2885 if k is None: 2886 raise ValueError('The last dimension of predictions_idx must be set.') 2887 labels = _maybe_expand_labels(labels, predictions_idx) 2888 2889 # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate 2890 # prediction for each k, so we can calculate separate true positive values 2891 # for each k. 2892 predictions_idx_per_k = array_ops.expand_dims( 2893 predictions_idx, -1, name='predictions_idx_per_k') 2894 2895 # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor. 2896 labels_per_k = _expand_and_tile( 2897 labels, multiple=k, dim=-1, name='labels_per_k') 2898 2899 # The following tensors are all of shape [D1, ... DN, k], containing values 2900 # per row, per k value. 2901 # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at 2902 # that k value is correct, 0 otherwise. This is the "rel_{i}" term from 2903 # the formula above. 2904 # `tp_per_k` (int32) - True positive counts. 2905 # `retrieved_per_k` (int32) - Number of predicted values at each k. This is 2906 # the precision denominator. 2907 # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}" 2908 # term from the formula above. 2909 # `relevant_precision_per_k` (float64) - Relevant precisions; i.e., 2910 # precisions at all k for which relevance indicator is true. 2911 relevant_per_k = _sparse_true_positive_at_k( 2912 labels_per_k, predictions_idx_per_k, name='relevant_per_k') 2913 tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k') 2914 retrieved_per_k = math_ops.cumsum( 2915 array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k') 2916 precision_per_k = math_ops.div( 2917 math_ops.to_double(tp_per_k), 2918 math_ops.to_double(retrieved_per_k), 2919 name='precision_per_k') 2920 relevant_precision_per_k = math_ops.multiply( 2921 precision_per_k, 2922 math_ops.to_double(relevant_per_k), 2923 name='relevant_precision_per_k') 2924 2925 # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor. 2926 precision_sum = math_ops.reduce_sum( 2927 relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum') 2928 2929 # Divide by number of relevant items to get average precision. 
These are 2930 # the "num_relevant_items" and "AveP" terms from the formula above. 2931 num_relevant_items = math_ops.to_double(_num_relevant(labels, k)) 2932 return math_ops.div(precision_sum, num_relevant_items, name=scope) 2933 2934 2935 def _streaming_sparse_average_precision_at_top_k(labels, 2936 predictions_idx, 2937 weights=None, 2938 metrics_collections=None, 2939 updates_collections=None, 2940 name=None): 2941 """Computes average precision@k of predictions with respect to sparse labels. 2942 2943 `sparse_average_precision_at_top_k` creates two local variables, 2944 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 2945 are used to compute the frequency. This frequency is ultimately returned as 2946 `average_precision_at_<k>`: an idempotent operation that simply divides 2947 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 2948 2949 For estimation of the metric over a stream of data, the function creates an 2950 `update_op` operation that updates these variables and returns the 2951 `precision_at_<k>`. Set operations applied to `top_k` and `labels` calculate 2952 the true positives and false positives weighted by `weights`. Then `update_op` 2953 increments `true_positive_at_<k>` and `false_positive_at_<k>` using these 2954 values. 2955 2956 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 2957 2958 Args: 2959 labels: `int64` `Tensor` or `SparseTensor` with shape 2960 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 2961 num_labels=1. N >= 1 and num_labels is the number of target classes for 2962 the associated prediction. Commonly, N=1 and `labels` has shape 2963 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. 2964 Values should be in range [0, num_classes). 2965 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1. 2966 Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final 2967 dimension contains the top `k` predicted class indices. [D1, ... DN] must 2968 match `labels`. Values should be in range [0, num_classes). 2969 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of 2970 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all 2971 dimensions must be either `1`, or the same as the corresponding `labels` 2972 dimension). 2973 metrics_collections: An optional list of collections that values should 2974 be added to. 2975 updates_collections: An optional list of collections that updates should 2976 be added to. 2977 name: Name of new update operation, and namespace for other dependent ops. 2978 2979 Returns: 2980 mean_average_precision: Scalar `float64` `Tensor` with the mean average 2981 precision values. 2982 update: `Operation` that increments variables appropriately, and whose 2983 value matches `metric`. 2984 """ 2985 with ops.name_scope(name, 'average_precision_at_top_k', 2986 (predictions_idx, labels, weights)) as scope: 2987 # Calculate per-example average precision, and apply weights. 2988 average_precision = _sparse_average_precision_at_top_k( 2989 predictions_idx=predictions_idx, labels=labels) 2990 if weights is not None: 2991 weights = weights_broadcast_ops.broadcast_weights( 2992 math_ops.to_double(weights), average_precision) 2993 average_precision = math_ops.multiply(average_precision, weights) 2994 2995 # Create accumulation variables and update ops for max average precision and 2996 # total average precision. 
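# Illustrative note (unweighted case): each row's average precision is at most
# 1.0, so after batches of, say, 4 and 6 rows, `max` accumulates 4 + 6 = 10
# while `total` accumulates the summed per-row average precisions; the metric
# computed below is total / max, i.e. the running mean average precision.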
2997 with ops.name_scope(None, 'max', (average_precision,)) as max_scope: 2998 # `max` is the max possible precision. Since max for any row is 1.0: 2999 # - For the unweighted case, this is just the number of rows. 3000 # - For the weighted case, it's the sum of the weights broadcast across 3001 # `average_precision` rows. 3002 max_var = metric_variable([], dtypes.float64, name=max_scope) 3003 if weights is None: 3004 batch_max = math_ops.to_double( 3005 array_ops.size(average_precision, name='batch_max')) 3006 else: 3007 batch_max = math_ops.reduce_sum(weights, name='batch_max') 3008 max_update = state_ops.assign_add(max_var, batch_max, name='update') 3009 with ops.name_scope(None, 'total', (average_precision,)) as total_scope: 3010 total_var = metric_variable([], dtypes.float64, name=total_scope) 3011 batch_total = math_ops.reduce_sum(average_precision, name='batch_total') 3012 total_update = state_ops.assign_add(total_var, batch_total, name='update') 3013 3014 # Divide total by max to get mean, for both vars and the update ops. 3015 mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean') 3016 update = _safe_scalar_div(total_update, max_update, name=scope) 3017 3018 if metrics_collections: 3019 ops.add_to_collections(metrics_collections, mean_average_precision) 3020 if updates_collections: 3021 ops.add_to_collections(updates_collections, update) 3022 3023 return mean_average_precision, update 3024 3025 3026 @tf_export('metrics.sparse_average_precision_at_k') 3027 @deprecated(None, 'Use average_precision_at_k instead') 3028 def sparse_average_precision_at_k(labels, 3029 predictions, 3030 k, 3031 weights=None, 3032 metrics_collections=None, 3033 updates_collections=None, 3034 name=None): 3035 """Renamed to `average_precision_at_k`, please use that method instead.""" 3036 return average_precision_at_k( 3037 labels=labels, 3038 predictions=predictions, 3039 k=k, 3040 weights=weights, 3041 metrics_collections=metrics_collections, 3042 updates_collections=updates_collections, 3043 name=name) 3044 3045 3046 @tf_export('metrics.average_precision_at_k') 3047 def average_precision_at_k(labels, 3048 predictions, 3049 k, 3050 weights=None, 3051 metrics_collections=None, 3052 updates_collections=None, 3053 name=None): 3054 """Computes average precision@k of predictions with respect to sparse labels. 3055 3056 `average_precision_at_k` creates two local variables, 3057 `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that 3058 are used to compute the frequency. This frequency is ultimately returned as 3059 `average_precision_at_<k>`: an idempotent operation that simply divides 3060 `average_precision_at_<k>/total` by `average_precision_at_<k>/max`. 3061 3062 For estimation of the metric over a stream of data, the function creates an 3063 `update_op` operation that updates these variables and returns the 3064 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3065 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3066 `labels` calculate the true positives and false positives weighted by 3067 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3068 `false_positive_at_<k>` using these values. 3069 3070 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 3071 3072 Args: 3073 labels: `int64` `Tensor` or `SparseTensor` with shape 3074 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies 3075 num_labels=1. 
N >= 1 and num_labels is the number of target classes for
3076 the associated prediction. Commonly, N=1 and `labels` has shape
3077 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3078 should be in range [0, num_classes), where num_classes is the last
3079 dimension of `predictions`. Values outside this range are ignored.
3080 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3081 N >= 1. Commonly, N=1 and `predictions` has shape
3082 [batch size, num_classes]. The final dimension contains the logit values
3083 for each class. [D1, ... DN] must match `labels`.
3084 k: Integer, k for @k metric. This will calculate an average precision for
3085 range `[1,k]`, as documented above.
3086 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3087 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3088 dimensions must be either `1`, or the same as the corresponding `labels`
3089 dimension).
3090 metrics_collections: An optional list of collections that values should
3091 be added to.
3092 updates_collections: An optional list of collections that updates should
3093 be added to.
3094 name: Name of new update operation, and namespace for other dependent ops.
3095
3096 Returns:
3097 mean_average_precision: Scalar `float64` `Tensor` with the mean average
3098 precision values.
3099 update: `Operation` that increments variables appropriately, and whose
3100 value matches `metric`.
3101
3102 Raises:
3103 ValueError: if k is invalid.
3104 RuntimeError: If eager execution is enabled.
3105 """
3106 if context.in_eager_mode():
3107 raise RuntimeError('tf.metrics.average_precision_at_k is not '
3108 'supported when eager execution is enabled.')
3109
3110 if k < 1:
3111 raise ValueError('Invalid k=%s.' % k)
3112 with ops.name_scope(name, _at_k_name('average_precision', k),
3113 (predictions, labels, weights)) as scope:
3114 # Calculate top k indices to produce [D1, ... DN, k] tensor.
3115 _, predictions_idx = nn.top_k(predictions, k)
3116 return _streaming_sparse_average_precision_at_top_k(
3117 labels=labels,
3118 predictions_idx=predictions_idx,
3119 weights=weights,
3120 metrics_collections=metrics_collections,
3121 updates_collections=updates_collections,
3122 name=scope)
3123
3124
3125 def _sparse_false_positive_at_k(labels,
3126 predictions_idx,
3127 class_id=None,
3128 weights=None):
3129 """Calculates false positives for precision@k.
3130
3131 If `class_id` is specified, calculate binary false positives for `class_id`
3132 only.
3133 If `class_id` is not specified, calculate metrics for `k` predicted vs
3134 `n` label classes, where `n` is the 2nd dimension of `labels`.
3135
3136 Args:
3137 labels: `int64` `Tensor` or `SparseTensor` with shape
3138 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3139 target classes for the associated prediction. Commonly, N=1 and `labels`
3140 has shape [batch_size, num_labels]. [D1, ... DN] must match
3141 `predictions_idx`.
3142 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3143 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3144 match `labels`.
3145 class_id: Class for which we want binary metrics.
3146 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3147 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3148 dimensions must be either `1`, or the same as the corresponding `labels`
3149 dimension).
3150
3151 Returns:
3152 A [D1, ... DN] `Tensor` of false positive counts.
3153 """
3154 with ops.name_scope(None, 'false_positives',
3155 (predictions_idx, labels, weights)):
3156 labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
3157 class_id)
3158 fp = sets.set_size(
3159 sets.set_difference(predictions_idx, labels, aminusb=True))
3160 fp = math_ops.to_double(fp)
3161 if weights is not None:
3162 with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
3163 weights, fp),)):
3164 weights = math_ops.to_double(weights)
3165 fp = math_ops.multiply(fp, weights)
3166 return fp
3167
3168
3169 def _streaming_sparse_false_positive_at_k(labels,
3170 predictions_idx,
3171 k=None,
3172 class_id=None,
3173 weights=None,
3174 name=None):
3175 """Calculates weighted per step false positives for precision@k.
3176
3177 If `class_id` is specified, calculate binary false positives for `class_id`
3178 only.
3179 If `class_id` is not specified, calculate metrics for `k` predicted vs
3180 `n` label classes, where `n` is the 2nd dimension of `labels`.
3181
3182 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3183
3184 Args:
3185 labels: `int64` `Tensor` or `SparseTensor` with shape
3186 [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
3187 target classes for the associated prediction. Commonly, N=1 and `labels`
3188 has shape [batch_size, num_labels]. [D1, ... DN] must match
3189 `predictions_idx`.
3190 predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
3191 top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
3192 match `labels`.
3193 k: Integer, k for @k metric. This is only used for default op name.
3194 class_id: Class for which we want binary metrics.
3195 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3196 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3197 dimensions must be either `1`, or the same as the corresponding `labels`
3198 dimension).
3199 name: Name of new variable, and namespace for other dependent ops.
3200
3201 Returns:
3202 A tuple of `Variable` and update `Operation`.
3203
3204 Raises:
3205 ValueError: If `weights` is not `None` and has an incompatible shape.
3206 """
3207 with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
3208 (predictions_idx, labels, weights)) as scope:
3209 fp = _sparse_false_positive_at_k(
3210 predictions_idx=predictions_idx,
3211 labels=labels,
3212 class_id=class_id,
3213 weights=weights)
3214 batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
3215
3216 var = metric_variable([], dtypes.float64, name=scope)
3217 return var, state_ops.assign_add(var, batch_total_fp, name='update')
3218
3219
3220 @tf_export('metrics.precision_at_top_k')
3221 def precision_at_top_k(labels,
3222 predictions_idx,
3223 k=None,
3224 class_id=None,
3225 weights=None,
3226 metrics_collections=None,
3227 updates_collections=None,
3228 name=None):
3229 """Computes precision@k of the predictions with respect to sparse labels.
3230
3231 Differs from `precision_at_k` in that predictions must be in the form
3232 of top `k` class indices, whereas `precision_at_k` expects logits.
3233 Refer to `precision_at_k` for more details.
3234
3235 Args:
3236 labels: `int64` `Tensor` or `SparseTensor` with shape
3237 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3238 num_labels=1. N >= 1 and num_labels is the number of target classes for
3239 the associated prediction. Commonly, N=1 and `labels` has shape
3240 [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`. Values
3241 should be in range [0, num_classes), where num_classes is the last
3242 dimension of `predictions`. Values outside this range are ignored.
3243 predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
3244 N >= 1. Commonly, N=1 and `predictions_idx` has shape [batch size, k].
3245 The final dimension contains the top `k` predicted class indices.
3246 [D1, ... DN] must match `labels`.
3247 k: Integer, k for @k metric. Only used for the default op name.
3248 class_id: Integer class ID for which we want binary metrics. This should be
3249 in range [0, num_classes), where num_classes is the last dimension of
3250 `predictions`. If `class_id` is outside this range, the method returns
3251 NAN.
3252 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3253 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3254 dimensions must be either `1`, or the same as the corresponding `labels`
3255 dimension).
3256 metrics_collections: An optional list of collections that values should
3257 be added to.
3258 updates_collections: An optional list of collections that updates should
3259 be added to.
3260 name: Name of new update operation, and namespace for other dependent ops.
3261
3262 Returns:
3263 precision: Scalar `float64` `Tensor` with the value of `true_positives`
3264 divided by the sum of `true_positives` and `false_positives`.
3265 update_op: `Operation` that increments `true_positives` and
3266 `false_positives` variables appropriately, and whose value matches
3267 `precision`.
3268
3269 Raises:
3270 ValueError: If `weights` is not `None` and its shape doesn't match
3271 `predictions_idx`, or if either `metrics_collections` or `updates_collections`
3272 are not a list or tuple.
3273 RuntimeError: If eager execution is enabled.
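For example, a minimal usage sketch (illustrative only; assumes graph mode, the
public `tf.metrics` endpoint exported above, and made-up tensor values):

  import tensorflow as tf

  labels = tf.constant([[0], [2]], dtype=tf.int64)
  # Top-2 predicted class indices for each of the two examples.
  predictions_idx = tf.constant([[0, 1], [1, 2]], dtype=tf.int64)
  precision, update_op = tf.metrics.precision_at_top_k(
      labels=labels, predictions_idx=predictions_idx, k=2)

  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(precision))  # 0.5: 2 true positives and 2 false positives.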
3274 """ 3275 if context.in_eager_mode(): 3276 raise RuntimeError('tf.metrics.precision_at_top_k is not ' 3277 'supported when eager execution is enabled.') 3278 3279 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id), 3280 (predictions_idx, labels, weights)) as scope: 3281 labels = _maybe_expand_labels(labels, predictions_idx) 3282 top_k_idx = math_ops.to_int64(predictions_idx) 3283 tp, tp_update = _streaming_sparse_true_positive_at_k( 3284 predictions_idx=top_k_idx, 3285 labels=labels, 3286 k=k, 3287 class_id=class_id, 3288 weights=weights) 3289 fp, fp_update = _streaming_sparse_false_positive_at_k( 3290 predictions_idx=top_k_idx, 3291 labels=labels, 3292 k=k, 3293 class_id=class_id, 3294 weights=weights) 3295 3296 metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope) 3297 update = math_ops.div( 3298 tp_update, math_ops.add(tp_update, fp_update), name='update') 3299 if metrics_collections: 3300 ops.add_to_collections(metrics_collections, metric) 3301 if updates_collections: 3302 ops.add_to_collections(updates_collections, update) 3303 return metric, update 3304 3305 3306 @tf_export('metrics.sparse_precision_at_k') 3307 @deprecated(None, 'Use precision_at_k instead') 3308 def sparse_precision_at_k(labels, 3309 predictions, 3310 k, 3311 class_id=None, 3312 weights=None, 3313 metrics_collections=None, 3314 updates_collections=None, 3315 name=None): 3316 """Renamed to `precision_at_k`, please use that method instead.""" 3317 return precision_at_k( 3318 labels=labels, 3319 predictions=predictions, 3320 k=k, 3321 class_id=class_id, 3322 weights=weights, 3323 metrics_collections=metrics_collections, 3324 updates_collections=updates_collections, 3325 name=name) 3326 3327 3328 @tf_export('metrics.precision_at_k') 3329 def precision_at_k(labels, 3330 predictions, 3331 k, 3332 class_id=None, 3333 weights=None, 3334 metrics_collections=None, 3335 updates_collections=None, 3336 name=None): 3337 """Computes precision@k of the predictions with respect to sparse labels. 3338 3339 If `class_id` is specified, we calculate precision by considering only the 3340 entries in the batch for which `class_id` is in the top-k highest 3341 `predictions`, and computing the fraction of them for which `class_id` is 3342 indeed a correct label. 3343 If `class_id` is not specified, we'll calculate precision as how often on 3344 average a class among the top-k classes with the highest predicted values 3345 of a batch entry is correct and can be found in the label for that entry. 3346 3347 `precision_at_k` creates two local variables, 3348 `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute 3349 the precision@k frequency. This frequency is ultimately returned as 3350 `precision_at_<k>`: an idempotent operation that simply divides 3351 `true_positive_at_<k>` by total (`true_positive_at_<k>` + 3352 `false_positive_at_<k>`). 3353 3354 For estimation of the metric over a stream of data, the function creates an 3355 `update_op` operation that updates these variables and returns the 3356 `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor` 3357 indicating the top `k` `predictions`. Set operations applied to `top_k` and 3358 `labels` calculate the true positives and false positives weighted by 3359 `weights`. Then `update_op` increments `true_positive_at_<k>` and 3360 `false_positive_at_<k>` using these values. 3361 3362 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values. 
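For example (an illustrative computation): with `k=2`, labels [[1], [1]] and `predictions` whose two highest-scoring classes per example are [1, 2] and [0, 3], the first example contributes 1 true positive and 1 false positive and the second contributes 2 false positives, so precision@2 = 1 / (1 + 3) = 0.25.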
3363
3364 Args:
3365 labels: `int64` `Tensor` or `SparseTensor` with shape
3366 [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
3367 num_labels=1. N >= 1 and num_labels is the number of target classes for
3368 the associated prediction. Commonly, N=1 and `labels` has shape
3369 [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
3370 should be in range [0, num_classes), where num_classes is the last
3371 dimension of `predictions`. Values outside this range are ignored.
3372 predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
3373 N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
3374 The final dimension contains the logit values for each class. [D1, ... DN]
3375 must match `labels`.
3376 k: Integer, k for @k metric.
3377 class_id: Integer class ID for which we want binary metrics. This should be
3378 in range [0, num_classes), where num_classes is the last dimension of
3379 `predictions`. If `class_id` is outside this range, the method returns
3380 NAN.
3381 weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
3382 `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
3383 dimensions must be either `1`, or the same as the corresponding `labels`
3384 dimension).
3385 metrics_collections: An optional list of collections that values should
3386 be added to.
3387 updates_collections: An optional list of collections that updates should
3388 be added to.
3389 name: Name of new update operation, and namespace for other dependent ops.
3390
3391 Returns:
3392 precision: Scalar `float64` `Tensor` with the value of `true_positives`
3393 divided by the sum of `true_positives` and `false_positives`.
3394 update_op: `Operation` that increments `true_positives` and
3395 `false_positives` variables appropriately, and whose value matches
3396 `precision`.
3397
3398 Raises:
3399 ValueError: If `weights` is not `None` and its shape doesn't match
3400 `predictions`, or if either `metrics_collections` or `updates_collections`
3401 are not a list or tuple.
3402 RuntimeError: If eager execution is enabled.
3403 """
3404 if context.in_eager_mode():
3405 raise RuntimeError('tf.metrics.precision_at_k is not '
3406 'supported when eager execution is enabled.')
3407
3408 with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
3409 (predictions, labels, weights)) as scope:
3410 _, top_k_idx = nn.top_k(predictions, k)
3411 return precision_at_top_k(
3412 labels=labels,
3413 predictions_idx=top_k_idx,
3414 k=k,
3415 class_id=class_id,
3416 weights=weights,
3417 metrics_collections=metrics_collections,
3418 updates_collections=updates_collections,
3419 name=scope)
3420
3421
3422 @tf_export('metrics.specificity_at_sensitivity')
3423 def specificity_at_sensitivity(labels,
3424 predictions,
3425 sensitivity,
3426 weights=None,
3427 num_thresholds=200,
3428 metrics_collections=None,
3429 updates_collections=None,
3430 name=None):
3431 """Computes the specificity at a given sensitivity.
3432
3433 The `specificity_at_sensitivity` function creates four local
3434 variables, `true_positives`, `true_negatives`, `false_positives` and
3435 `false_negatives` that are used to compute the specificity at the given
3436 sensitivity value. The threshold for the given sensitivity value is computed
3437 and used to evaluate the corresponding specificity.
3438
3439 For estimation of the metric over a stream of data, the function creates an
3440 `update_op` operation that updates these variables and returns the
3441 `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
3442 `false_positives` and `false_negatives` counts with the weight of each case
3443 found in the `predictions` and `labels`.
3444
3445 If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
3446
3447 For additional information about specificity and sensitivity, see the
3448 following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
3449
3450 Args:
3451 labels: The ground truth values, a `Tensor` whose dimensions must match
3452 `predictions`. Will be cast to `bool`.
3453 predictions: A floating point `Tensor` of arbitrary shape and whose values
3454 are in the range `[0, 1]`.
3455 sensitivity: A scalar value in range `[0, 1]`.
3456 weights: Optional `Tensor` whose rank is either 0, or the same rank as
3457 `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
3458 be either `1`, or the same as the corresponding `labels` dimension).
3459 num_thresholds: The number of thresholds to use for matching the given
3460 sensitivity.
3461 metrics_collections: An optional list of collections that `specificity`
3462 should be added to.
3463 updates_collections: An optional list of collections that `update_op` should
3464 be added to.
3465 name: An optional variable_scope name.
3466
3467 Returns:
3468 specificity: A scalar `Tensor` representing the specificity at the given
3469 `sensitivity` value.
3470 update_op: An operation that increments the `true_positives`,
3471 `true_negatives`, `false_positives` and `false_negatives` variables
3472 appropriately and whose value matches `specificity`.
3473
3474 Raises:
3475 ValueError: If `predictions` and `labels` have mismatched shapes, if
3476 `weights` is not `None` and its shape doesn't match `predictions`, or if
3477 `sensitivity` is not between 0 and 1, or if either `metrics_collections`
3478 or `updates_collections` are not a list or tuple.
3479 RuntimeError: If eager execution is enabled.
3480 """
3481 if context.in_eager_mode():
3482 raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
3483 'supported when eager execution is enabled.')
3484
3485 if sensitivity < 0 or sensitivity > 1:
3486 raise ValueError('`sensitivity` must be in the range [0, 1].')
3487
3488 with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
3489 (predictions, labels, weights)):
3490 kepsilon = 1e-7  # to account for floating point imprecisions
3491 thresholds = [
3492 (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
3493 ]
3494 thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]
3495
3496 values, update_ops = _confusion_matrix_at_thresholds(
3497 labels, predictions, thresholds, weights)
3498
3499 def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
3500 """Computes the specificity at the given sensitivity.
3501
3502 Args:
3503 tp: True positives.
3504 tn: True negatives.
3505 fp: False positives.
3506 fn: False negatives.
3507 name: The name of the operation.
3508
3509 Returns:
3510 The specificity using the aggregated values.
3511 """
3512 sensitivities = math_ops.div(tp, tp + fn + kepsilon)
3513
3514 # We'll need to use this trick until tf.argmax allows us to specify
3515 # whether we should use the first or last index in case of ties.
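# Concretely: `indices_at_minval` below marks every threshold whose
# sensitivity is closest to the target, the cumulative sum first reaches its
# maximum at the last marked entry, and `argmax` (which returns the first
# occurrence of the maximum) therefore picks the last such threshold.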
3516 min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity)) 3517 indices_at_minval = math_ops.equal( 3518 math_ops.abs(sensitivities - sensitivity), min_val) 3519 indices_at_minval = math_ops.to_int64(indices_at_minval) 3520 indices_at_minval = math_ops.cumsum(indices_at_minval) 3521 tf_index = math_ops.argmax(indices_at_minval, 0) 3522 tf_index = math_ops.cast(tf_index, dtypes.int32) 3523 3524 # Now, we have the implicit threshold, so compute the specificity: 3525 return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon, 3526 name) 3527 3528 specificity = compute_specificity_at_sensitivity( 3529 values['tp'], values['tn'], values['fp'], values['fn'], 'value') 3530 update_op = compute_specificity_at_sensitivity( 3531 update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'], 3532 'update_op') 3533 3534 if metrics_collections: 3535 ops.add_to_collections(metrics_collections, specificity) 3536 3537 if updates_collections: 3538 ops.add_to_collections(updates_collections, update_op) 3539 3540 return specificity, update_op 3541