# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of tf.metrics module."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.eager import context
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import confusion_matrix
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import sets
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.ops import weights_broadcast_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.tf_export import tf_export


def metric_variable(shape, dtype, validate_shape=True, name=None):
     41   """Create variable in `GraphKeys.(LOCAL|METRIC_VARIABLES`) collections."""

  return variable_scope.variable(
      lambda: array_ops.zeros(shape, dtype),
      trainable=False,
      collections=[
          ops.GraphKeys.LOCAL_VARIABLES, ops.GraphKeys.METRIC_VARIABLES
      ],
      validate_shape=validate_shape,
      name=name)


def _remove_squeezable_dimensions(predictions, labels, weights):
  """Squeeze or expand last dim if needed.

  Squeezes last dim of `predictions` or `labels` if their rank differs by 1
  (using confusion_matrix.remove_squeezable_dimensions).
  Squeezes or expands last dim of `weights` if its rank differs by 1 from the
  new rank of `predictions`.

  If `weights` is scalar, it is kept scalar.

  This will use static shape if available. Otherwise, it will add graph
  operations, which could result in a performance hit.

  Args:
    predictions: Predicted values, a `Tensor` of arbitrary dimensions.
    labels: Optional label `Tensor` whose dimensions match `predictions`.
    weights: Optional weight scalar or `Tensor` whose dimensions match
      `predictions`.

  Returns:
    Tuple of `predictions`, `labels` and `weights`. Each of them possibly has
    the last dimension squeezed, `weights` could be extended by one dimension.
  """
  predictions = ops.convert_to_tensor(predictions)
  if labels is not None:
    labels, predictions = confusion_matrix.remove_squeezable_dimensions(
        labels, predictions)
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

  if weights is None:
    return predictions, labels, None

  weights = ops.convert_to_tensor(weights)
  weights_shape = weights.get_shape()
  weights_rank = weights_shape.ndims
  if weights_rank == 0:
    return predictions, labels, weights

  predictions_shape = predictions.get_shape()
  predictions_rank = predictions_shape.ndims
  if (predictions_rank is not None) and (weights_rank is not None):
    # Use static rank.
    if weights_rank - predictions_rank == 1:
      weights = array_ops.squeeze(weights, [-1])
    elif predictions_rank - weights_rank == 1:
      weights = array_ops.expand_dims(weights, [-1])
  else:
    # Use dynamic rank.
    weights_rank_tensor = array_ops.rank(weights)
    rank_diff = weights_rank_tensor - array_ops.rank(predictions)

    def _maybe_expand_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, -1),
          lambda: array_ops.expand_dims(weights, [-1]), lambda: weights)

    # Don't attempt squeeze if it will fail based on static check.
    if ((weights_rank is not None) and
        (not weights_shape.dims[-1].is_compatible_with(1))):
      maybe_squeeze_weights = lambda: weights
    else:
      maybe_squeeze_weights = lambda: array_ops.squeeze(weights, [-1])

    def _maybe_adjust_weights():
      return control_flow_ops.cond(
          math_ops.equal(rank_diff, 1), maybe_squeeze_weights,
          _maybe_expand_weights)

    # If weights are scalar, do nothing. Otherwise, try to add or remove a
    # dimension to match predictions.
    weights = control_flow_ops.cond(
        math_ops.equal(weights_rank_tensor, 0), lambda: weights,
        _maybe_adjust_weights)
  return predictions, labels, weights

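# Example (a minimal sketch, not part of the original module): with
# `predictions` of shape [2, 1] and `labels` of shape [2], the trailing
# dimension of `predictions` is squeezed away, and a rank-2 `weights` tensor
# is then squeezed to match the new rank:
#
#   preds = ops.convert_to_tensor([[0.3], [0.8]])
#   labls = ops.convert_to_tensor([0.0, 1.0])
#   wts = ops.convert_to_tensor([[1.0], [0.5]])
#   preds, labls, wts = _remove_squeezable_dimensions(preds, labls, wts)
#   # preds, labls and wts now all have shape [2].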

def _maybe_expand_labels(labels, predictions):
  """If necessary, expand `labels` along last dimension to match `predictions`.

  Args:
    labels: `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN]. The latter implies
      num_labels=1, in which case the result is an expanded `labels` with shape
      [D1, ... DN, 1].
    predictions: `Tensor` with shape [D1, ... DN, num_classes].

  Returns:
    `labels` with the same rank as `predictions`.

  Raises:
    ValueError: if `labels` has invalid shape.
  """
  with ops.name_scope(None, 'expand_labels', (labels, predictions)) as scope:
    labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)

    # If sparse, expand sparse shape.
    if isinstance(labels, sparse_tensor.SparseTensor):
      return control_flow_ops.cond(
          math_ops.equal(
              array_ops.rank(predictions),
              array_ops.size(labels.dense_shape) + 1),
          lambda: sparse_ops.sparse_reshape(  # pylint: disable=g-long-lambda
              labels,
              shape=array_ops.concat((labels.dense_shape, (1,)), 0),
              name=scope),
          lambda: labels)

    # Otherwise, try to use static shape.
    labels_rank = labels.get_shape().ndims
    if labels_rank is not None:
      predictions_rank = predictions.get_shape().ndims
      if predictions_rank is not None:
        if predictions_rank == labels_rank:
          return labels
        if predictions_rank == labels_rank + 1:
          return array_ops.expand_dims(labels, -1, name=scope)
        raise ValueError(
            'Unexpected labels shape %s for predictions shape %s.' %
            (labels.get_shape(), predictions.get_shape()))

    # Otherwise, use dynamic shape.
    return control_flow_ops.cond(
        math_ops.equal(array_ops.rank(predictions),
                       array_ops.rank(labels) + 1),
        lambda: array_ops.expand_dims(labels, -1, name=scope), lambda: labels)


def _safe_div(numerator, denominator, name):
  """Divides two tensors element-wise, returning 0 if the denominator is <= 0.

  Args:
    numerator: A real `Tensor`.
    denominator: A real `Tensor`, with dtype matching `numerator`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` <= 0, else `numerator` / `denominator`
  """
  t = math_ops.truediv(numerator, denominator)
  zero = array_ops.zeros_like(t, dtype=denominator.dtype)
  condition = math_ops.greater(denominator, zero)
  zero = math_ops.cast(zero, t.dtype)
  return array_ops.where(condition, t, zero, name=name)

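# Example (a minimal sketch, not part of the original module): entries with a
# non-positive denominator come back as 0 instead of inf or a negative value:
#
#   num = ops.convert_to_tensor([1.0, 2.0, 3.0])
#   den = ops.convert_to_tensor([2.0, 0.0, -1.0])
#   _safe_div(num, den, 'div')  # -> [0.5, 0.0, 0.0]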

def _safe_scalar_div(numerator, denominator, name):
  """Divides two values, returning 0 if the denominator is 0.

  Args:
    numerator: A scalar `float64` `Tensor`.
    denominator: A scalar `float64` `Tensor`.
    name: Name for the returned op.

  Returns:
    0 if `denominator` == 0, else `numerator` / `denominator`
  """
  numerator.get_shape().with_rank_at_most(1)
  denominator.get_shape().with_rank_at_most(1)
  return control_flow_ops.cond(
      math_ops.equal(
          array_ops.constant(0.0, dtype=dtypes.float64), denominator),
      lambda: array_ops.constant(0.0, dtype=dtypes.float64),
      lambda: math_ops.div(numerator, denominator),
      name=name)


def _streaming_confusion_matrix(labels, predictions, num_classes, weights=None):
  """Calculate a streaming confusion matrix.

  Calculates a confusion matrix. For estimation over a stream of data,
  the function creates an `update_op` operation.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).

  Returns:
    total_cm: A `Tensor` representing the confusion matrix.
    update_op: An operation that increments the confusion matrix.
  """
  # Local variable to accumulate the predictions in the confusion matrix.
  total_cm = metric_variable(
      [num_classes, num_classes], dtypes.float64, name='total_confusion_matrix')

  # Cast the type to int64 required by confusion_matrix_ops.
  predictions = math_ops.to_int64(predictions)
  labels = math_ops.to_int64(labels)
  num_classes = math_ops.to_int64(num_classes)

  # Flatten the input if its rank > 1.
  if predictions.get_shape().ndims > 1:
    predictions = array_ops.reshape(predictions, [-1])

  if labels.get_shape().ndims > 1:
    labels = array_ops.reshape(labels, [-1])

  if (weights is not None) and (weights.get_shape().ndims > 1):
    weights = array_ops.reshape(weights, [-1])

  # Accumulate the prediction to current confusion matrix.
  current_cm = confusion_matrix.confusion_matrix(
      labels, predictions, num_classes, weights=weights, dtype=dtypes.float64)
  update_op = state_ops.assign_add(total_cm, current_cm)
  return total_cm, update_op

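# Example (a minimal sketch, not part of the original module): each run of
# `update_op` adds one batch's confusion counts into `total_cm`:
#
#   labels = ops.convert_to_tensor([0, 0, 1, 1])
#   preds = ops.convert_to_tensor([0, 1, 1, 1])
#   total_cm, update_op = _streaming_confusion_matrix(labels, preds, 2)
#   # After one run of update_op, total_cm holds [[1., 1.], [0., 2.]]
#   # (rows index labels, columns index predictions).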

@tf_export('metrics.mean')
def mean(values,
         weights=None,
         metrics_collections=None,
         updates_collections=None,
         name=None):
  """Computes the (weighted) mean of the given values.

  The `mean` function creates two local variables, `total` and `count`
  that are used to compute the average of `values`. This average is ultimately
  returned as `mean` which is an idempotent operation that simply divides
  `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean`.
  `update_op` increments `total` with the reduced sum of the product of `values`
  and `weights`, and it increments `count` with the reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    values: A `Tensor` of arbitrary dimensions.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `values`, and must be broadcastable to `values` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `values` dimension).
    metrics_collections: An optional list of collections that `mean`
      should be added to.
    updates_collections: An optional list of collections that `update_op`
      should be added to.
    name: An optional variable_scope name.

  Returns:
    mean: A `Tensor` representing the current mean, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match `values`,
      or if either `metrics_collections` or `updates_collections` are not a list
      or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'mean', (values, weights)):
    values = math_ops.to_float(values)

    total = metric_variable([], dtypes.float32, name='total')
    count = metric_variable([], dtypes.float32, name='count')

    if weights is None:
      num_values = math_ops.to_float(array_ops.size(values))
    else:
      values, _, weights = _remove_squeezable_dimensions(
          predictions=values, labels=None, weights=weights)
      weights = weights_broadcast_ops.broadcast_weights(
          math_ops.to_float(weights), values)
      values = math_ops.multiply(values, weights)
      num_values = math_ops.reduce_sum(weights)

    update_total_op = state_ops.assign_add(total, math_ops.reduce_sum(values))
    with ops.control_dependencies([values]):
      update_count_op = state_ops.assign_add(count, num_values)

    mean_t = _safe_div(total, count, 'value')
    update_op = _safe_div(update_total_op, update_count_op, 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_t)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_t, update_op

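# Example usage (a minimal sketch; the placeholder and session names below are
# illustrative, not part of this module):
#
#   import tensorflow as tf
#
#   values = tf.placeholder(tf.float32, shape=[None])
#   mean_t, update_op = tf.metrics.mean(values)
#   with tf.Session() as sess:
#     # `total` and `count` are local variables, hence this initializer.
#     sess.run(tf.local_variables_initializer())
#     sess.run(update_op, feed_dict={values: [1., 2., 3.]})
#     sess.run(update_op, feed_dict={values: [5.]})
#     print(sess.run(mean_t))  # (1 + 2 + 3 + 5) / 4 = 2.75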

@tf_export('metrics.accuracy')
def accuracy(labels,
             predictions,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculates how often `predictions` matches `labels`.

  The `accuracy` function creates two local variables, `total` and
  `count` that are used to compute the frequency with which `predictions`
  matches `labels`. This frequency is ultimately returned as `accuracy`: an
  idempotent operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `accuracy`.
  Internally, an `is_correct` operation computes a `Tensor` with elements 1.0
  where the corresponding elements of `predictions` and `labels` match and 0.0
  otherwise. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `is_correct`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: The ground truth values, a `Tensor` whose shape matches
      `predictions`.
    predictions: The predicted values, a `Tensor` of any shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `accuracy` should
      be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    accuracy: A `Tensor` representing the accuracy, the value of `total` divided
      by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `accuracy`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.accuracy is not supported when eager '
                       'execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  predictions.get_shape().assert_is_compatible_with(labels.get_shape())
  if labels.dtype != predictions.dtype:
    predictions = math_ops.cast(predictions, labels.dtype)
  is_correct = math_ops.to_float(math_ops.equal(predictions, labels))
  return mean(is_correct, weights, metrics_collections, updates_collections,
              name or 'accuracy')

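# Example usage (a minimal sketch, assuming a TF 1.x session as above):
#
#   import tensorflow as tf
#
#   labels = tf.constant([1, 0, 1, 1])
#   predictions = tf.constant([1, 0, 0, 1])
#   acc, acc_update = tf.metrics.accuracy(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(acc_update)   # 3 of 4 predictions match
#     print(sess.run(acc))   # -> 0.75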

def _confusion_matrix_at_thresholds(labels,
                                    predictions,
                                    thresholds,
                                    weights=None,
                                    includes=None):
  """Computes true_positives, false_negatives, true_negatives, false_positives.

  This function creates up to four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives`.
  `true_positive[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `false_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `True`.
  `true_negatives[i]` is defined as the total weight of values in `predictions`
  at most `thresholds[i]` whose corresponding entry in `labels` is `False`.
  `false_positives[i]` is defined as the total weight of values in `predictions`
  above `thresholds[i]` whose corresponding entry in `labels` is `False`.

  For estimation of these metrics over a stream of data, for each metric the
  function respectively creates an `update_op` operation that updates the
  variable and returns its value.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    thresholds: A python list or tuple of float thresholds in `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    includes: Tuple of keys to return, from 'tp', 'fn', 'tn', 'fp'. If `None`,
        default to all four.

  Returns:
    values: Dict of variables of shape `[len(thresholds)]`. Keys are from
        `includes`.
    update_ops: Dict of operations that increments the `values`. Keys are from
        `includes`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `includes` contains invalid keys.
  """
  all_includes = ('tp', 'fn', 'tn', 'fp')
  if includes is None:
    includes = all_includes
  else:
    for include in includes:
      if include not in all_includes:
        raise ValueError('Invalid key: %s.' % include)

  with ops.control_dependencies([
      check_ops.assert_greater_equal(
          predictions,
          math_ops.cast(0.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]'),
      check_ops.assert_less_equal(
          predictions,
          math_ops.cast(1.0, dtype=predictions.dtype),
          message='predictions must be in [0, 1]')
  ]):
    predictions, labels, weights = _remove_squeezable_dimensions(
        predictions=math_ops.to_float(predictions),
        labels=math_ops.cast(labels, dtype=dtypes.bool),
        weights=weights)

  num_thresholds = len(thresholds)

  # Reshape predictions and labels.
  predictions_2d = array_ops.reshape(predictions, [-1, 1])
  labels_2d = array_ops.reshape(
      math_ops.cast(labels, dtype=dtypes.bool), [1, -1])

  # Use static shape if known.
  num_predictions = predictions_2d.get_shape().as_list()[0]

  # Otherwise use dynamic shape.
  if num_predictions is None:
    num_predictions = array_ops.shape(predictions_2d)[0]
  thresh_tiled = array_ops.tile(
      array_ops.expand_dims(array_ops.constant(thresholds), [1]),
      array_ops.stack([1, num_predictions]))

  # Tile the predictions after thresholding them across different thresholds.
  pred_is_pos = math_ops.greater(
      array_ops.tile(array_ops.transpose(predictions_2d), [num_thresholds, 1]),
      thresh_tiled)
  if ('fn' in includes) or ('tn' in includes):
    pred_is_neg = math_ops.logical_not(pred_is_pos)

  # Tile labels by number of thresholds
  label_is_pos = array_ops.tile(labels_2d, [num_thresholds, 1])
  if ('fp' in includes) or ('tn' in includes):
    label_is_neg = math_ops.logical_not(label_is_pos)

  if weights is not None:
    weights = weights_broadcast_ops.broadcast_weights(
        math_ops.to_float(weights), predictions)
    weights_tiled = array_ops.tile(
        array_ops.reshape(weights, [1, -1]), [num_thresholds, 1])
    thresh_tiled.get_shape().assert_is_compatible_with(
        weights_tiled.get_shape())
  else:
    weights_tiled = None

  values = {}
  update_ops = {}

  if 'tp' in includes:
    true_p = metric_variable(
        [num_thresholds], dtypes.float32, name='true_positives')
    is_true_positive = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_pos))
    if weights_tiled is not None:
      is_true_positive *= weights_tiled
    update_ops['tp'] = state_ops.assign_add(true_p,
                                            math_ops.reduce_sum(
                                                is_true_positive, 1))
    values['tp'] = true_p

  if 'fn' in includes:
    false_n = metric_variable(
        [num_thresholds], dtypes.float32, name='false_negatives')
    is_false_negative = math_ops.to_float(
        math_ops.logical_and(label_is_pos, pred_is_neg))
    if weights_tiled is not None:
      is_false_negative *= weights_tiled
    update_ops['fn'] = state_ops.assign_add(false_n,
                                            math_ops.reduce_sum(
                                                is_false_negative, 1))
    values['fn'] = false_n

  if 'tn' in includes:
    true_n = metric_variable(
        [num_thresholds], dtypes.float32, name='true_negatives')
    is_true_negative = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_neg))
    if weights_tiled is not None:
      is_true_negative *= weights_tiled
    update_ops['tn'] = state_ops.assign_add(true_n,
                                            math_ops.reduce_sum(
                                                is_true_negative, 1))
    values['tn'] = true_n

  if 'fp' in includes:
    false_p = metric_variable(
        [num_thresholds], dtypes.float32, name='false_positives')
    is_false_positive = math_ops.to_float(
        math_ops.logical_and(label_is_neg, pred_is_pos))
    if weights_tiled is not None:
      is_false_positive *= weights_tiled
    update_ops['fp'] = state_ops.assign_add(false_p,
                                            math_ops.reduce_sum(
                                                is_false_positive, 1))
    values['fp'] = false_p

  return values, update_ops

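# Example (a minimal sketch, not part of the original module): requesting only
# true and false positives at two thresholds:
#
#   labels = ops.convert_to_tensor([False, True, True])
#   preds = ops.convert_to_tensor([0.2, 0.6, 0.9])
#   values, update_ops = _confusion_matrix_at_thresholds(
#       labels, preds, thresholds=[0.5, 0.8], includes=('tp', 'fp'))
#   # After running update_ops['tp'] and update_ops['fp'] once:
#   #   values['tp'] holds [2., 1.]  (preds above 0.5 / 0.8 with label True)
#   #   values['fp'] holds [0., 0.]  (no positive preds with label False)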

@tf_export('metrics.auc')
def auc(labels,
        predictions,
        weights=None,
        num_thresholds=200,
        metrics_collections=None,
        updates_collections=None,
        curve='ROC',
        name=None,
        summation_method='trapezoidal'):
  """Computes the approximate AUC via a Riemann sum.

  The `auc` function creates four local variables, `true_positives`,
  `true_negatives`, `false_positives` and `false_negatives` that are used to
  compute the AUC. To discretize the AUC curve, a linearly spaced set of
  thresholds is used to compute pairs of recall and precision values. The area
  under the ROC-curve is therefore computed using the height of the recall
  values by the false positive rate, while the area under the PR-curve is
  computed using the height of the precision values by the recall.

  This value is ultimately returned as `auc`, an idempotent operation that
  computes the area under a discretized curve of precision versus recall values
  (computed using the aforementioned variables). The `num_thresholds` variable
  controls the degree of discretization with larger numbers of thresholds more
  closely approximating the true AUC. The quality of the approximation may vary
  dramatically depending on `num_thresholds`.

  For best results, `predictions` should be distributed approximately uniformly
  in the range [0, 1] and not peaked around 0 or 1. The quality of the AUC
  approximation may be poor if this is not the case. Setting `summation_method`
  to 'minoring' or 'majoring' can help quantify the error in the approximation
  by providing a lower or upper bound estimate of the AUC.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `auc`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` whose shape matches `predictions`. Will be cast to
      `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use when discretizing the roc
      curve.
    metrics_collections: An optional list of collections that `auc` should be
      added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    curve: Specifies the name of the curve to be computed, 'ROC' [default] or
      'PR' for the Precision-Recall-curve.
    name: An optional variable_scope name.
    summation_method: Specifies the Riemann summation method used, 'trapezoidal'
      [default] that applies the trapezoidal rule, 'minoring' that applies
      left summation for increasing intervals and right summation for decreasing
      intervals or 'majoring' that applies the opposite.

  Returns:
    auc: A scalar `Tensor` representing the current area-under-curve.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `auc`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.auc is not supported when eager execution '
                       'is enabled.')

  with variable_scope.variable_scope(name, 'auc',
                                     (labels, predictions, weights)):
    if curve != 'ROC' and curve != 'PR':
      raise ValueError('curve must be either ROC or PR, %s unknown' % (curve))
    kepsilon = 1e-7  # to account for floating point imprecisions
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    # Add epsilons to avoid dividing by 0.
    epsilon = 1.0e-6

    def compute_auc(tp, fn, tn, fp, name):
      """Computes the roc-auc or pr-auc based on confusion counts."""
      rec = math_ops.div(tp + epsilon, tp + fn + epsilon)
      if curve == 'ROC':
        fp_rate = math_ops.div(fp, fp + tn + epsilon)
        x = fp_rate
        y = rec
      else:  # curve == 'PR'.
        prec = math_ops.div(tp, tp + fp + epsilon)
        x = rec
        y = prec
      if summation_method == 'trapezoidal':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              (y[:num_thresholds - 1] + y[1:]) / 2.),
            name=name)
      elif summation_method == 'minoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.minimum(y[:num_thresholds - 1], y[1:])),
            name=name)
      elif summation_method == 'majoring':
        return math_ops.reduce_sum(
            math_ops.multiply(x[:num_thresholds - 1] - x[1:],
                              math_ops.maximum(y[:num_thresholds - 1], y[1:])),
            name=name)
      else:
        raise ValueError('Invalid summation_method: %s' % summation_method)

    # sum up the areas of all the trapeziums
    auc_value = compute_auc(values['tp'], values['fn'], values['tn'],
                            values['fp'], 'value')
    update_op = compute_auc(update_ops['tp'], update_ops['fn'],
                            update_ops['tn'], update_ops['fp'], 'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, auc_value)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return auc_value, update_op

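# Example usage (a minimal sketch, assuming a TF 1.x session):
#
#   import tensorflow as tf
#
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
#   auc_value, auc_update = tf.metrics.auc(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(auc_update)
#     print(sess.run(auc_value))  # approximately 0.75 for this toy data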

@tf_export('metrics.mean_absolute_error')
def mean_absolute_error(labels,
                        predictions,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean absolute error between the labels and predictions.

  The `mean_absolute_error` function creates two local variables,
  `total` and `count` that are used to compute the mean absolute error. This
  average is weighted by `weights`, and it is ultimately returned as
  `mean_absolute_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_absolute_error`. Internally, an `absolute_errors` operation computes the
  absolute value of the differences between `predictions` and `labels`. Then
  `update_op` increments `total` with the reduced sum of the product of
  `weights` and `absolute_errors`, and it increments `count` with the reduced
  sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_absolute_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_absolute_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_absolute_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_absolute_error is not supported '
                       'when eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  absolute_errors = math_ops.abs(predictions - labels)
  return mean(absolute_errors, weights, metrics_collections,
              updates_collections, name or 'mean_absolute_error')

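# Example usage (a minimal sketch, assuming a TF 1.x session):
#
#   import tensorflow as tf
#
#   labels = tf.constant([1., 2.])
#   predictions = tf.constant([2., 4.])
#   mae, mae_update = tf.metrics.mean_absolute_error(labels, predictions)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(mae_update)
#     print(sess.run(mae))  # (|2-1| + |4-2|) / 2 = 1.5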

@tf_export('metrics.mean_cosine_distance')
def mean_cosine_distance(labels,
                         predictions,
                         dim,
                         weights=None,
                         metrics_collections=None,
                         updates_collections=None,
                         name=None):
  """Computes the cosine distance between the labels and predictions.

  The `mean_cosine_distance` function creates two local variables,
  `total` and `count` that are used to compute the average cosine distance
  between `predictions` and `labels`. This average is weighted by `weights`,
  and it is ultimately returned as `mean_distance`, which is an idempotent
  operation that simply divides `total` by `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_distance`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of arbitrary shape.
    predictions: A `Tensor` of the same shape as `labels`.
    dim: The dimension along which the cosine distance is computed.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension). Also,
      dimension `dim` must be `1`.
    metrics_collections: An optional list of collections that the metric
      value variable should be added to.
    updates_collections: An optional list of collections that the metric update
      ops should be added to.
    name: An optional variable_scope name.

  Returns:
    mean_distance: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_cosine_distance is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)
  radial_diffs = math_ops.multiply(predictions, labels)
  radial_diffs = math_ops.reduce_sum(
      radial_diffs, reduction_indices=[
          dim,
      ], keepdims=True)
  mean_distance, update_op = mean(radial_diffs, weights, None, None, name or
                                  'mean_cosine_distance')
  mean_distance = math_ops.subtract(1.0, mean_distance)
  update_op = math_ops.subtract(1.0, update_op)

  if metrics_collections:
    ops.add_to_collections(metrics_collections, mean_distance)

  if updates_collections:
    ops.add_to_collections(updates_collections, update_op)

  return mean_distance, update_op

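# Example usage (a minimal sketch; note the inputs are expected to be
# unit-normalized along `dim`, since only the dot product is taken):
#
#   import tensorflow as tf
#
#   labels = tf.constant([[1., 0.], [0., 1.]])
#   predictions = tf.constant([[1., 0.], [1., 0.]])
#   dist, dist_update = tf.metrics.mean_cosine_distance(
#       labels, predictions, dim=1)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(dist_update)
#     print(sess.run(dist))  # (0 + 1) / 2 = 0.5, i.e. 1 - mean cosine sim.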

@tf_export('metrics.mean_per_class_accuracy')
def mean_per_class_accuracy(labels,
                            predictions,
                            num_classes,
                            weights=None,
                            metrics_collections=None,
                            updates_collections=None,
                            name=None):
  """Calculates the mean of the per-class accuracies.

  Calculates the accuracy for each class, then takes the mean of that.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates the accuracy of each class and returns
  them.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since two variables with shape =
      [num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_per_class_accuracy` should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_accuracy: A `Tensor` representing the mean per class accuracy.
    update_op: An operation that updates the accuracy tensor.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_per_class_accuracy is not supported '
                       'when eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_accuracy',
                                     (predictions, labels, weights)):
    labels = math_ops.to_int64(labels)

    # Flatten the input if its rank > 1.
    if labels.get_shape().ndims > 1:
      labels = array_ops.reshape(labels, [-1])

    if predictions.get_shape().ndims > 1:
      predictions = array_ops.reshape(predictions, [-1])

    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total = metric_variable([num_classes], dtypes.float32, name='total')
    count = metric_variable([num_classes], dtypes.float32, name='count')

    ones = array_ops.ones([array_ops.size(labels)], dtypes.float32)

    if labels.dtype != predictions.dtype:
      predictions = math_ops.cast(predictions, labels.dtype)
    is_correct = math_ops.to_float(math_ops.equal(predictions, labels))

    if weights is not None:
      if weights.get_shape().ndims > 1:
        weights = array_ops.reshape(weights, [-1])
      weights = math_ops.to_float(weights)

      is_correct *= weights
      ones *= weights

    update_total_op = state_ops.scatter_add(total, labels, ones)
    update_count_op = state_ops.scatter_add(count, labels, is_correct)

    per_class_accuracy = _safe_div(count, total, None)

    mean_accuracy_v = math_ops.reduce_mean(
        per_class_accuracy, name='mean_accuracy')
    update_op = _safe_div(update_count_op, update_total_op, name='update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_accuracy_v)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_accuracy_v, update_op

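# Example usage (a minimal sketch, assuming a TF 1.x session):
#
#   import tensorflow as tf
#
#   labels = tf.constant([0, 0, 1])
#   predictions = tf.constant([0, 1, 1])
#   mpca, mpca_update = tf.metrics.mean_per_class_accuracy(
#       labels, predictions, num_classes=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(mpca_update)
#     print(sess.run(mpca))  # (1/2 + 1/1) / 2 = 0.75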

@tf_export('metrics.mean_iou')
def mean_iou(labels,
             predictions,
             num_classes,
             weights=None,
             metrics_collections=None,
             updates_collections=None,
             name=None):
  """Calculate per-step mean Intersection-Over-Union (mIOU).

  Mean Intersection-Over-Union is a common evaluation metric for
  semantic image segmentation, which first computes the IOU for each
  semantic class and then computes the average over classes.
  IOU is defined as follows:
    IOU = true_positive / (true_positive + false_positive + false_negative).
  The predictions are accumulated in a confusion matrix, weighted by `weights`,
  and mIOU is then calculated from it.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the `mean_iou`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of ground truth labels with shape [batch size] and of
      type `int32` or `int64`. The tensor will be flattened if its rank > 1.
    predictions: A `Tensor` of prediction results for semantic labels, whose
      shape is [batch size] and type `int32` or `int64`. The tensor will be
      flattened if its rank > 1.
    num_classes: The possible number of labels the prediction task can
      have. This value must be provided, since a confusion matrix of
      dimension = [num_classes, num_classes] will be allocated.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that `mean_iou`
      should be added to.
    updates_collections: An optional list of collections `update_op` should be
      added to.
    name: An optional variable_scope name.

  Returns:
    mean_iou: A `Tensor` representing the mean intersection-over-union.
    update_op: An operation that increments the confusion matrix.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_iou is not supported when '
                       'eager execution is enabled.')

  with variable_scope.variable_scope(name, 'mean_iou',
                                     (predictions, labels, weights)):
    # Check if shape is compatible.
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    total_cm, update_op = _streaming_confusion_matrix(labels, predictions,
                                                      num_classes, weights)

    def compute_mean_iou(name):
      """Compute the mean intersection-over-union via the confusion matrix."""
      sum_over_row = math_ops.to_float(math_ops.reduce_sum(total_cm, 0))
      sum_over_col = math_ops.to_float(math_ops.reduce_sum(total_cm, 1))
      cm_diag = math_ops.to_float(array_ops.diag_part(total_cm))
      denominator = sum_over_row + sum_over_col - cm_diag

      # The mean is only computed over classes that appear in the
      # label or prediction tensor. If the denominator is 0, we need to
      # ignore the class.
      num_valid_entries = math_ops.reduce_sum(
          math_ops.cast(
              math_ops.not_equal(denominator, 0), dtype=dtypes.float32))

      # If the value of the denominator is 0, set it to 1 to avoid
      # zero division.
      denominator = array_ops.where(
          math_ops.greater(denominator, 0), denominator,
          array_ops.ones_like(denominator))
      iou = math_ops.div(cm_diag, denominator)

      # If the number of valid entries is 0 (no classes) we return 0.
      result = array_ops.where(
          math_ops.greater(num_valid_entries, 0),
          math_ops.reduce_sum(iou, name=name) / num_valid_entries, 0)
      return result

    mean_iou_v = compute_mean_iou('mean_iou')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, mean_iou_v)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return mean_iou_v, update_op

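# Example usage (a minimal sketch, assuming a TF 1.x session):
#
#   import tensorflow as tf
#
#   labels = tf.constant([0, 0, 1, 1])
#   predictions = tf.constant([0, 1, 1, 1])
#   miou, miou_update = tf.metrics.mean_iou(labels, predictions, num_classes=2)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(miou_update)
#     # Class 0: IOU = 1/2; class 1: IOU = 2/3; mean ~= 0.5833.
#     print(sess.run(miou))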

@tf_export('metrics.mean_relative_error')
def mean_relative_error(labels,
                        predictions,
                        normalizer,
                        weights=None,
                        metrics_collections=None,
                        updates_collections=None,
                        name=None):
  """Computes the mean relative error by normalizing with the given values.

  The `mean_relative_error` function creates two local variables,
  `total` and `count` that are used to compute the mean relative absolute error.
  This average is weighted by `weights`, and it is ultimately returned as
  `mean_relative_error`: an idempotent operation that simply divides `total` by
  `count`.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `mean_relative_error`. Internally, a `relative_errors` operation divides the
  absolute value of the differences between `predictions` and `labels` by the
  `normalizer`. Then `update_op` increments `total` with the reduced sum of the
  product of `weights` and `relative_errors`, and it increments `count` with the
  reduced sum of `weights`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  Args:
    labels: A `Tensor` of the same shape as `predictions`.
    predictions: A `Tensor` of arbitrary shape.
    normalizer: A `Tensor` of the same shape as `predictions`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    metrics_collections: An optional list of collections that
      `mean_relative_error` should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    mean_relative_error: A `Tensor` representing the current mean, the value of
      `total` divided by `count`.
    update_op: An operation that increments the `total` and `count` variables
      appropriately and whose value matches `mean_relative_error`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, or if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      either `metrics_collections` or `updates_collections` are not a list or
      tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.mean_relative_error is not supported when '
                       'eager execution is enabled.')

  predictions, labels, weights = _remove_squeezable_dimensions(
      predictions=predictions, labels=labels, weights=weights)

  predictions, normalizer = confusion_matrix.remove_squeezable_dimensions(
      predictions, normalizer)
  predictions.get_shape().assert_is_compatible_with(normalizer.get_shape())
  relative_errors = array_ops.where(
      math_ops.equal(normalizer, 0.0), array_ops.zeros_like(labels),
      math_ops.div(math_ops.abs(labels - predictions), normalizer))
  return mean(relative_errors, weights, metrics_collections,
              updates_collections, name or 'mean_relative_error')

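# Example usage (a minimal sketch; here the labels themselves serve as the
# normalizer, which is a common choice):
#
#   import tensorflow as tf
#
#   labels = tf.constant([2., 4.])
#   predictions = tf.constant([1., 2.])
#   mre, mre_update = tf.metrics.mean_relative_error(
#       labels, predictions, normalizer=labels)
#   with tf.Session() as sess:
#     sess.run(tf.local_variables_initializer())
#     sess.run(mre_update)
#     print(sess.run(mre))  # (|2-1|/2 + |4-2|/4) / 2 = 0.5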
   1117 
   1118 @tf_export('metrics.mean_squared_error')
   1119 def mean_squared_error(labels,
   1120                        predictions,
   1121                        weights=None,
   1122                        metrics_collections=None,
   1123                        updates_collections=None,
   1124                        name=None):
   1125   """Computes the mean squared error between the labels and predictions.
   1126 
   1127   The `mean_squared_error` function creates two local variables,
   1128   `total` and `count` that are used to compute the mean squared error.
   1129   This average is weighted by `weights`, and it is ultimately returned as
   1130   `mean_squared_error`: an idempotent operation that simply divides `total` by
   1131   `count`.
   1132 
   1133   For estimation of the metric over a stream of data, the function creates an
   1134   `update_op` operation that updates these variables and returns the
   1135   `mean_squared_error`. Internally, a `squared_error` operation computes the
   1136   element-wise square of the difference between `predictions` and `labels`. Then
   1137   `update_op` increments `total` with the reduced sum of the product of
   1138   `weights` and `squared_error`, and it increments `count` with the reduced sum
   1139   of `weights`.
   1140 
   1141   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1142 
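          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1., 2., 4.])       # illustrative data
              predictions = tf.constant([1., 3., 6.])
              mse, update_op = tf.metrics.mean_squared_error(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(mse))  # ~1.6667, the mean of [0., 1., 4.]
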
   1143   Args:
   1144     labels: A `Tensor` of the same shape as `predictions`.
   1145     predictions: A `Tensor` of arbitrary shape.
   1146     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1147       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1148       be either `1`, or the same as the corresponding `labels` dimension).
   1149     metrics_collections: An optional list of collections that
   1150       `mean_squared_error` should be added to.
   1151     updates_collections: An optional list of collections that `update_op` should
   1152       be added to.
   1153     name: An optional variable_scope name.
   1154 
   1155   Returns:
   1156     mean_squared_error: A `Tensor` representing the current mean, the value of
   1157       `total` divided by `count`.
   1158     update_op: An operation that increments the `total` and `count` variables
   1159       appropriately and whose value matches `mean_squared_error`.
   1160 
   1161   Raises:
   1162     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1163       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1164       either `metrics_collections` or `updates_collections` are not a list or
   1165       tuple.
   1166     RuntimeError: If eager execution is enabled.
   1167   """
   1168   if context.in_eager_mode():
   1169     raise RuntimeError('tf.metrics.mean_squared_error is not supported when '
   1170                        'eager execution is enabled.')
   1171 
   1172   predictions, labels, weights = _remove_squeezable_dimensions(
   1173       predictions=predictions, labels=labels, weights=weights)
   1174   squared_error = math_ops.square(labels - predictions)
   1175   return mean(squared_error, weights, metrics_collections, updates_collections,
   1176               name or 'mean_squared_error')
   1177 
   1178 
   1179 @tf_export('metrics.mean_tensor')
   1180 def mean_tensor(values,
   1181                 weights=None,
   1182                 metrics_collections=None,
   1183                 updates_collections=None,
   1184                 name=None):
   1185   """Computes the element-wise (weighted) mean of the given tensors.
   1186 
   1187   In contrast to the `mean` function, which returns a scalar with the
   1188   mean, this function returns an average tensor with the same shape as the
   1189   input tensors.
   1190 
   1191   The `mean_tensor` function creates two local variables,
   1192   `total_tensor` and `count_tensor`, that are used to compute the average
   1193   of `values`. This average is ultimately returned as `mean`, which is an
   1194   idempotent operation that simply divides `total` by `count`.
   1195 
   1196   For estimation of the metric over a stream of data, the function creates an
   1197   `update_op` operation that updates these variables and returns the `mean`.
   1198   `update_op` increments `total` with the reduced sum of the product of `values`
   1199   and `weights`, and it increments `count` with the reduced sum of `weights`.
   1200 
   1201   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1202 
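          For example, a sketch that streams two made-up batches (`tf` is
          assumed to be the imported tensorflow module):

              values = tf.placeholder(tf.float32, shape=[2])  # illustrative
              mean_t, update_op = tf.metrics.mean_tensor(values)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op, feed_dict={values: [0., 2.]})
                sess.run(update_op, feed_dict={values: [4., 6.]})
                print(sess.run(mean_t))  # [2., 4.], the element-wise mean
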
   1203   Args:
   1204     values: A `Tensor` of arbitrary dimensions.
   1205     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1206       `values`, and must be broadcastable to `values` (i.e., all dimensions must
   1207       be either `1`, or the same as the corresponding `values` dimension).
   1208     metrics_collections: An optional list of collections that `mean`
   1209       should be added to.
   1210     updates_collections: An optional list of collections that `update_op`
   1211       should be added to.
   1212     name: An optional variable_scope name.
   1213 
   1214   Returns:
   1215     mean: A float `Tensor` representing the current mean, the value of `total`
   1216       divided by `count`.
   1217     update_op: An operation that increments the `total` and `count` variables
   1218       appropriately and whose value matches `mean`.
   1219 
   1220   Raises:
   1221     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
   1222       or if either `metrics_collections` or `updates_collections` are not a list
   1223       or tuple.
   1224     RuntimeError: If eager execution is enabled.
   1225   """
   1226   if context.in_eager_mode():
   1227     raise RuntimeError('tf.metrics.mean_tensor is not supported when '
   1228                        'eager execution is enabled.')
   1229 
   1230   with variable_scope.variable_scope(name, 'mean', (values, weights)):
   1231     values = math_ops.to_float(values)
   1232     total = metric_variable(
   1233         values.get_shape(), dtypes.float32, name='total_tensor')
   1234     count = metric_variable(
   1235         values.get_shape(), dtypes.float32, name='count_tensor')
   1236 
   1237     num_values = array_ops.ones_like(values)
   1238     if weights is not None:
   1239       values, _, weights = _remove_squeezable_dimensions(
   1240           predictions=values, labels=None, weights=weights)
   1241       weights = weights_broadcast_ops.broadcast_weights(
   1242           math_ops.to_float(weights), values)
   1243       values = math_ops.multiply(values, weights)
   1244       num_values = math_ops.multiply(num_values, weights)
   1245 
   1246     update_total_op = state_ops.assign_add(total, values)
   1247     with ops.control_dependencies([values]):
   1248       update_count_op = state_ops.assign_add(count, num_values)
   1249 
   1250     def compute_mean(total, count, name):
   1251       non_zero_count = math_ops.maximum(
   1252           count, array_ops.ones_like(count), name=name)
   1253       return math_ops.truediv(total, non_zero_count, name=name)
   1254 
   1255     mean_t = compute_mean(total, count, 'value')
   1256     update_op = compute_mean(update_total_op, update_count_op, 'update_op')
   1257 
   1258     if metrics_collections:
   1259       ops.add_to_collections(metrics_collections, mean_t)
   1260 
   1261     if updates_collections:
   1262       ops.add_to_collections(updates_collections, update_op)
   1263 
   1264     return mean_t, update_op
   1265 
   1266 
   1267 @tf_export('metrics.percentage_below')
   1268 def percentage_below(values,
   1269                      threshold,
   1270                      weights=None,
   1271                      metrics_collections=None,
   1272                      updates_collections=None,
   1273                      name=None):
   1274   """Computes the percentage of values less than the given threshold.
   1275 
   1276   The `percentage_below` function creates two local variables,
   1277   `total` and `count`, that are used to compute the percentage of `values`
   1278   that fall below `threshold`. This rate is weighted by `weights`, and it is
   1279   ultimately returned as `percentage`, which is an idempotent operation that
   1280   simply divides `total` by `count`.
   1281 
   1282   For estimation of the metric over a stream of data, the function creates an
   1283   `update_op` operation that updates these variables and returns the
   1284   `percentage`.
   1285 
   1286   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1287 
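          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              values = tf.constant([1., 4., 6., 8.])  # illustrative data
              pct, update_op = tf.metrics.percentage_below(values, 5.0)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(pct))  # 0.5: two of the four values are below 5.0
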
   1288   Args:
   1289     values: A numeric `Tensor` of arbitrary size.
   1290     threshold: A scalar threshold.
   1291     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1292       `values`, and must be broadcastable to `values` (i.e., all dimensions must
   1293       be either `1`, or the same as the corresponding `values` dimension).
   1294     metrics_collections: An optional list of collections that the metric
   1295       value variable should be added to.
   1296     updates_collections: An optional list of collections that the metric update
   1297       ops should be added to.
   1298     name: An optional variable_scope name.
   1299 
   1300   Returns:
   1301     percentage: A `Tensor` representing the current mean, the value of `total`
   1302       divided by `count`.
   1303     update_op: An operation that increments the `total` and `count` variables
   1304       appropriately.
   1305 
   1306   Raises:
   1307     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
   1308       or if either `metrics_collections` or `updates_collections` are not a list
   1309       or tuple.
   1310     RuntimeError: If eager execution is enabled.
   1311   """
   1312   if context.in_eager_mode():
   1313     raise RuntimeError('tf.metrics.percentage_below is not supported when '
   1314                        'eager execution is enabled.')
   1315 
   1316   is_below_threshold = math_ops.to_float(math_ops.less(values, threshold))
   1317   return mean(is_below_threshold, weights, metrics_collections,
   1318               updates_collections, name or 'percentage_below_threshold')
   1319 
   1320 
   1321 def _count_condition(values,
   1322                      weights=None,
   1323                      metrics_collections=None,
   1324                      updates_collections=None):
   1325   """Sums the weights of cases where the given values are True.
   1326 
   1327   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1328 
   1329   Args:
   1330     values: A `bool` `Tensor` of arbitrary size.
   1331     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1332       `values`, and must be broadcastable to `values` (i.e., all dimensions must
   1333       be either `1`, or the same as the corresponding `values` dimension).
   1334     metrics_collections: An optional list of collections that the metric
   1335       value variable should be added to.
   1336     updates_collections: An optional list of collections that the metric update
   1337       ops should be added to.
   1338 
   1339   Returns:
   1340     value_tensor: A `Tensor` representing the current value of the metric.
   1341     update_op: An operation that accumulates the error from a batch of data.
   1342 
   1343   Raises:
   1344     ValueError: If `weights` is not `None` and its shape doesn't match `values`,
   1345       or if either `metrics_collections` or `updates_collections` are not a list
   1346       or tuple.
   1347   """
   1348   check_ops.assert_type(values, dtypes.bool)
   1349   count = metric_variable([], dtypes.float32, name='count')
   1350 
   1351   values = math_ops.to_float(values)
   1352   if weights is not None:
   1353     with ops.control_dependencies((check_ops.assert_rank_in(
   1354         weights, (0, array_ops.rank(values))),)):
   1355       weights = math_ops.to_float(weights)
   1356       values = math_ops.multiply(values, weights)
   1357 
   1358   value_tensor = array_ops.identity(count)
   1359   update_op = state_ops.assign_add(count, math_ops.reduce_sum(values))
   1360 
   1361   if metrics_collections:
   1362     ops.add_to_collections(metrics_collections, value_tensor)
   1363 
   1364   if updates_collections:
   1365     ops.add_to_collections(updates_collections, update_op)
   1366 
   1367   return value_tensor, update_op
   1368 
   1369 
   1370 @tf_export('metrics.false_negatives')
   1371 def false_negatives(labels,
   1372                     predictions,
   1373                     weights=None,
   1374                     metrics_collections=None,
   1375                     updates_collections=None,
   1376                     name=None):
   1377   """Computes the total number of false negatives.
   1378 
   1379   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1380 
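          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 1, 0, 1])       # illustrative data
              predictions = tf.constant([0, 1, 0, 0])
              fn, update_op = tf.metrics.false_negatives(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(fn))  # 2.0: positives at indices 0 and 3 missed
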
   1381   Args:
   1382     labels: The ground truth values, a `Tensor` whose dimensions must match
   1383       `predictions`. Will be cast to `bool`.
   1384     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   1385       be cast to `bool`.
   1386     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1387       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1388       be either `1`, or the same as the corresponding `labels` dimension).
   1389     metrics_collections: An optional list of collections that the metric
   1390       value variable should be added to.
   1391     updates_collections: An optional list of collections that the metric update
   1392       ops should be added to.
   1393     name: An optional variable_scope name.
   1394 
   1395   Returns:
   1396     value_tensor: A `Tensor` representing the current value of the metric.
   1397     update_op: An operation that accumulates the error from a batch of data.
   1398 
   1399   Raises:
   1400     ValueError: If `weights` is not `None` and its shape doesn't match
   1401       `predictions`, or if either `metrics_collections` or
   1402       `updates_collections` are not a list or tuple.
   1403     RuntimeError: If eager execution is enabled.
   1404   """
   1405   if context.in_eager_mode():
   1406     raise RuntimeError('tf.metrics.false_negatives is not supported when '
   1407                        'eager execution is enabled.')
   1408 
   1409   with variable_scope.variable_scope(name, 'false_negatives',
   1410                                      (predictions, labels, weights)):
   1411 
   1412     predictions, labels, weights = _remove_squeezable_dimensions(
   1413         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   1414         labels=math_ops.cast(labels, dtype=dtypes.bool),
   1415         weights=weights)
   1416     is_false_negative = math_ops.logical_and(
   1417         math_ops.equal(labels, True), math_ops.equal(predictions, False))
   1418     return _count_condition(is_false_negative, weights, metrics_collections,
   1419                             updates_collections)
   1420 
   1421 
   1422 @tf_export('metrics.false_negatives_at_thresholds')
   1423 def false_negatives_at_thresholds(labels,
   1424                                   predictions,
   1425                                   thresholds,
   1426                                   weights=None,
   1427                                   metrics_collections=None,
   1428                                   updates_collections=None,
   1429                                   name=None):
   1430   """Computes false negatives at provided threshold values.
   1431 
   1432   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1433 
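          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 1, 0])             # illustrative data
              predictions = tf.constant([0.9, 0.3, 0.6])
              fn, update_op = tf.metrics.false_negatives_at_thresholds(
                  labels, predictions, thresholds=[0.15, 0.5, 0.85])
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(fn))  # [0., 1., 1.]
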
   1434   Args:
   1435     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
   1436       `bool`.
   1437     predictions: A floating point `Tensor` of arbitrary shape and whose values
   1438       are in the range `[0, 1]`.
   1439     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   1440     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1441       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1442       be either `1`, or the same as the corresponding `labels` dimension).
   1443     metrics_collections: An optional list of collections that `false_negatives`
   1444       should be added to.
   1445     updates_collections: An optional list of collections that `update_op` should
   1446       be added to.
   1447     name: An optional variable_scope name.
   1448 
   1449   Returns:
   1450     false_negatives: A float `Tensor` of shape `[len(thresholds)]`.
   1451     update_op: An operation that updates the `false_negatives` variable and
   1452       returns its current value.
   1453 
   1454   Raises:
   1455     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1456       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1457       either `metrics_collections` or `updates_collections` are not a list or
   1458       tuple.
   1459     RuntimeError: If eager execution is enabled.
   1460   """
   1461   if context.in_eager_mode():
   1462     raise RuntimeError('tf.metrics.false_negatives_at_thresholds is not '
   1463                        'supported when eager execution is enabled.')
   1464 
   1465   with variable_scope.variable_scope(name, 'false_negatives',
   1466                                      (predictions, labels, weights)):
   1467     values, update_ops = _confusion_matrix_at_thresholds(
   1468         labels, predictions, thresholds, weights=weights, includes=('fn',))
   1469 
   1470     if metrics_collections:
   1471       ops.add_to_collections(metrics_collections, values['fn'])
   1472 
   1473     if updates_collections:
   1474       ops.add_to_collections(updates_collections, update_ops['fn'])
   1475 
   1476     return values['fn'], update_ops['fn']
   1477 
   1478 
   1479 @tf_export('metrics.false_positives')
   1480 def false_positives(labels,
   1481                     predictions,
   1482                     weights=None,
   1483                     metrics_collections=None,
   1484                     updates_collections=None,
   1485                     name=None):
   1486   """Sum the weights of false positives.
   1487 
   1488   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1489 
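          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([0, 0, 1, 0])       # illustrative data
              predictions = tf.constant([1, 0, 1, 1])
              fp, update_op = tf.metrics.false_positives(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(fp))  # 2.0: spurious positives at indices 0 and 3
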
   1490   Args:
   1491     labels: The ground truth values, a `Tensor` whose dimensions must match
   1492       `predictions`. Will be cast to `bool`.
   1493     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   1494       be cast to `bool`.
   1495     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1496       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1497       be either `1`, or the same as the corresponding `labels` dimension).
   1498     metrics_collections: An optional list of collections that the metric
   1499       value variable should be added to.
   1500     updates_collections: An optional list of collections that the metric update
   1501       ops should be added to.
   1502     name: An optional variable_scope name.
   1503 
   1504   Returns:
   1505     value_tensor: A `Tensor` representing the current value of the metric.
   1506     update_op: An operation that accumulates the error from a batch of data.
   1507 
   1508   Raises:
   1509     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1510       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1511       either `metrics_collections` or `updates_collections` are not a list or
   1512       tuple.
   1513     RuntimeError: If eager execution is enabled.
   1514   """
   1515   if context.in_eager_mode():
   1516     raise RuntimeError('tf.metrics.false_positives is not supported when '
   1517                        'eager execution is enabled.')
   1518 
   1519   with variable_scope.variable_scope(name, 'false_positives',
   1520                                      (predictions, labels, weights)):
   1521 
   1522     predictions, labels, weights = _remove_squeezable_dimensions(
   1523         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   1524         labels=math_ops.cast(labels, dtype=dtypes.bool),
   1525         weights=weights)
   1526     is_false_positive = math_ops.logical_and(
   1527         math_ops.equal(labels, False), math_ops.equal(predictions, True))
   1528     return _count_condition(is_false_positive, weights, metrics_collections,
   1529                             updates_collections)
   1530 
   1531 
   1532 @tf_export('metrics.false_positives_at_thresholds')
   1533 def false_positives_at_thresholds(labels,
   1534                                   predictions,
   1535                                   thresholds,
   1536                                   weights=None,
   1537                                   metrics_collections=None,
   1538                                   updates_collections=None,
   1539                                   name=None):
   1540   """Computes false positives at provided threshold values.
   1541 
   1542   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1543 
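          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([0, 0, 1])             # illustrative data
              predictions = tf.constant([0.9, 0.3, 0.6])
              fp, update_op = tf.metrics.false_positives_at_thresholds(
                  labels, predictions, thresholds=[0.15, 0.5, 0.85])
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(fp))  # [2., 1., 1.]
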
   1544   Args:
   1545     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
   1546       `bool`.
   1547     predictions: A floating point `Tensor` of arbitrary shape and whose values
   1548       are in the range `[0, 1]`.
   1549     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   1550     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1551       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1552       be either `1`, or the same as the corresponding `labels` dimension).
   1553     metrics_collections: An optional list of collections that `false_positives`
   1554       should be added to.
   1555     updates_collections: An optional list of collections that `update_op` should
   1556       be added to.
   1557     name: An optional variable_scope name.
   1558 
   1559   Returns:
   1560     false_positives: A float `Tensor` of shape `[len(thresholds)]`.
   1561     update_op: An operation that updates the `false_positives` variable and
   1562       returns its current value.
   1563 
   1564   Raises:
   1565     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1566       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1567       either `metrics_collections` or `updates_collections` are not a list or
   1568       tuple.
   1569     RuntimeError: If eager execution is enabled.
   1570   """
   1571   if context.in_eager_mode():
   1572     raise RuntimeError('tf.metrics.false_positives_at_thresholds is not '
   1573                        'supported when eager execution is enabled.')
   1574 
   1575   with variable_scope.variable_scope(name, 'false_positives',
   1576                                      (predictions, labels, weights)):
   1577     values, update_ops = _confusion_matrix_at_thresholds(
   1578         labels, predictions, thresholds, weights=weights, includes=('fp',))
   1579 
   1580     if metrics_collections:
   1581       ops.add_to_collections(metrics_collections, values['fp'])
   1582 
   1583     if updates_collections:
   1584       ops.add_to_collections(updates_collections, update_ops['fp'])
   1585 
   1586     return values['fp'], update_ops['fp']
   1587 
   1588 
   1589 @tf_export('metrics.true_negatives')
   1590 def true_negatives(labels,
   1591                    predictions,
   1592                    weights=None,
   1593                    metrics_collections=None,
   1594                    updates_collections=None,
   1595                    name=None):
   1596   """Sum the weights of true negatives.
   1597 
   1598   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1599 
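          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([0, 1, 0, 0])       # illustrative data
              predictions = tf.constant([0, 1, 1, 0])
              tn, update_op = tf.metrics.true_negatives(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(tn))  # 2.0: indices 0 and 3 are true negatives
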
   1600   Args:
   1601     labels: The ground truth values, a `Tensor` whose dimensions must match
   1602       `predictions`. Will be cast to `bool`.
   1603     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   1604       be cast to `bool`.
   1605     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1606       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1607       be either `1`, or the same as the corresponding `labels` dimension).
   1608     metrics_collections: An optional list of collections that the metric
   1609       value variable should be added to.
   1610     updates_collections: An optional list of collections that the metric update
   1611       ops should be added to.
   1612     name: An optional variable_scope name.
   1613 
   1614   Returns:
   1615     value_tensor: A `Tensor` representing the current value of the metric.
   1616     update_op: An operation that accumulates the error from a batch of data.
   1617 
   1618   Raises:
   1619     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1620       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1621       either `metrics_collections` or `updates_collections` are not a list or
   1622       tuple.
   1623     RuntimeError: If eager execution is enabled.
   1624   """
   1625   if context.in_eager_mode():
   1626     raise RuntimeError('tf.metrics.true_negatives is not '
   1627                        'supported when eager execution is enabled.')
   1628 
   1629   with variable_scope.variable_scope(name, 'true_negatives',
   1630                                      (predictions, labels, weights)):
   1631 
   1632     predictions, labels, weights = _remove_squeezable_dimensions(
   1633         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   1634         labels=math_ops.cast(labels, dtype=dtypes.bool),
   1635         weights=weights)
   1636     is_true_negative = math_ops.logical_and(
   1637         math_ops.equal(labels, False), math_ops.equal(predictions, False))
   1638     return _count_condition(is_true_negative, weights, metrics_collections,
   1639                             updates_collections)
   1640 
   1641 
   1642 @tf_export('metrics.true_negatives_at_thresholds')
   1643 def true_negatives_at_thresholds(labels,
   1644                                  predictions,
   1645                                  thresholds,
   1646                                  weights=None,
   1647                                  metrics_collections=None,
   1648                                  updates_collections=None,
   1649                                  name=None):
   1650   """Computes true negatives at provided threshold values.
   1651 
   1652   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1653 
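          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([0, 0, 1])             # illustrative data
              predictions = tf.constant([0.9, 0.3, 0.6])
              tn, update_op = tf.metrics.true_negatives_at_thresholds(
                  labels, predictions, thresholds=[0.15, 0.5, 0.85])
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(tn))  # [0., 1., 1.]
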
   1654   Args:
   1655     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
   1656       `bool`.
   1657     predictions: A floating point `Tensor` of arbitrary shape and whose values
   1658       are in the range `[0, 1]`.
   1659     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   1660     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1661       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1662       be either `1`, or the same as the corresponding `labels` dimension).
   1663     metrics_collections: An optional list of collections that `true_negatives`
   1664       should be added to.
   1665     updates_collections: An optional list of collections that `update_op` should
   1666       be added to.
   1667     name: An optional variable_scope name.
   1668 
   1669   Returns:
   1670     true_negatives: A float `Tensor` of shape `[len(thresholds)]`.
   1671     update_op: An operation that updates the `true_negatives` variable and
   1672       returns its current value.
   1673 
   1674   Raises:
   1675     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1676       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1677       either `metrics_collections` or `updates_collections` are not a list or
   1678       tuple.
   1679     RuntimeError: If eager execution is enabled.
   1680   """
   1681   if context.in_eager_mode():
   1682     raise RuntimeError('tf.metrics.true_negatives_at_thresholds is not '
   1683                        'supported when eager execution is enabled.')
   1684 
   1685   with variable_scope.variable_scope(name, 'true_negatives',
   1686                                      (predictions, labels, weights)):
   1687     values, update_ops = _confusion_matrix_at_thresholds(
   1688         labels, predictions, thresholds, weights=weights, includes=('tn',))
   1689 
   1690     if metrics_collections:
   1691       ops.add_to_collections(metrics_collections, values['tn'])
   1692 
   1693     if updates_collections:
   1694       ops.add_to_collections(updates_collections, update_ops['tn'])
   1695 
   1696     return values['tn'], update_ops['tn']
   1697 
   1698 
   1699 @tf_export('metrics.true_positives')
   1700 def true_positives(labels,
   1701                    predictions,
   1702                    weights=None,
   1703                    metrics_collections=None,
   1704                    updates_collections=None,
   1705                    name=None):
   1706   """Sum the weights of true positives.
   1707 
   1708   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1709 
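          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 0, 1, 1])       # illustrative data
              predictions = tf.constant([1, 1, 1, 0])
              tp, update_op = tf.metrics.true_positives(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(tp))  # 2.0: indices 0 and 2 are true positives
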
   1710   Args:
   1711     labels: The ground truth values, a `Tensor` whose dimensions must match
   1712       `predictions`. Will be cast to `bool`.
   1713     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   1714       be cast to `bool`.
   1715     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1716       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1717       be either `1`, or the same as the corresponding `labels` dimension).
   1718     metrics_collections: An optional list of collections that the metric
   1719       value variable should be added to.
   1720     updates_collections: An optional list of collections that the metric update
   1721       ops should be added to.
   1722     name: An optional variable_scope name.
   1723 
   1724   Returns:
   1725     value_tensor: A `Tensor` representing the current value of the metric.
   1726     update_op: An operation that accumulates the error from a batch of data.
   1727 
   1728   Raises:
   1729     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1730       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1731       either `metrics_collections` or `updates_collections` are not a list or
   1732       tuple.
   1733     RuntimeError: If eager execution is enabled.
   1734   """
   1735   if context.in_eager_mode():
   1736     raise RuntimeError('tf.metrics.true_positives is not '
   1737                        'supported when eager execution is enabled.')
   1738 
   1739   with variable_scope.variable_scope(name, 'true_positives',
   1740                                      (predictions, labels, weights)):
   1741 
   1742     predictions, labels, weights = _remove_squeezable_dimensions(
   1743         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   1744         labels=math_ops.cast(labels, dtype=dtypes.bool),
   1745         weights=weights)
   1746     is_true_positive = math_ops.logical_and(
   1747         math_ops.equal(labels, True), math_ops.equal(predictions, True))
   1748     return _count_condition(is_true_positive, weights, metrics_collections,
   1749                             updates_collections)
   1750 
   1751 
   1752 @tf_export('metrics.true_positives_at_thresholds')
   1753 def true_positives_at_thresholds(labels,
   1754                                  predictions,
   1755                                  thresholds,
   1756                                  weights=None,
   1757                                  metrics_collections=None,
   1758                                  updates_collections=None,
   1759                                  name=None):
   1760   """Computes true positives at provided threshold values.
   1761 
   1762   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1763 
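          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 0, 1])             # illustrative data
              predictions = tf.constant([0.9, 0.3, 0.6])
              tp, update_op = tf.metrics.true_positives_at_thresholds(
                  labels, predictions, thresholds=[0.15, 0.5, 0.85])
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(tp))  # [2., 2., 1.]
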
   1764   Args:
   1765     labels: A `Tensor` whose shape matches `predictions`. Will be cast to
   1766       `bool`.
   1767     predictions: A floating point `Tensor` of arbitrary shape and whose values
   1768       are in the range `[0, 1]`.
   1769     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   1770     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1771       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1772       be either `1`, or the same as the corresponding `labels` dimension).
   1773     metrics_collections: An optional list of collections that `true_positives`
   1774       should be added to.
   1775     updates_collections: An optional list of collections that `update_op` should
   1776       be added to.
   1777     name: An optional variable_scope name.
   1778 
   1779   Returns:
   1780     true_positives: A float `Tensor` of shape `[len(thresholds)]`.
   1781     update_op: An operation that updates the `true_positives` variable and
   1782       returns its current value.
   1783 
   1784   Raises:
   1785     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1786       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1787       either `metrics_collections` or `updates_collections` are not a list or
   1788       tuple.
   1789     RuntimeError: If eager execution is enabled.
   1790   """
   1791   if context.in_eager_mode():
   1792     raise RuntimeError('tf.metrics.true_positives_at_thresholds is not '
   1793                        'supported when eager execution is enabled.')
   1794 
   1795   with variable_scope.variable_scope(name, 'true_positives',
   1796                                      (predictions, labels, weights)):
   1797     values, update_ops = _confusion_matrix_at_thresholds(
   1798         labels, predictions, thresholds, weights=weights, includes=('tp',))
   1799 
   1800     if metrics_collections:
   1801       ops.add_to_collections(metrics_collections, values['tp'])
   1802 
   1803     if updates_collections:
   1804       ops.add_to_collections(updates_collections, update_ops['tp'])
   1805 
   1806     return values['tp'], update_ops['tp']
   1807 
   1808 
   1809 @tf_export('metrics.precision')
   1810 def precision(labels,
   1811               predictions,
   1812               weights=None,
   1813               metrics_collections=None,
   1814               updates_collections=None,
   1815               name=None):
   1816   """Computes the precision of the predictions with respect to the labels.
   1817 
   1818   The `precision` function creates two local variables,
   1819   `true_positives` and `false_positives`, that are used to compute the
   1820   precision. This value is ultimately returned as `precision`, an idempotent
   1821   operation that simply divides `true_positives` by the sum of `true_positives`
   1822   and `false_positives`.
   1823 
   1824   For estimation of the metric over a stream of data, the function creates an
   1825   `update_op` operation that updates these variables and returns the
   1826   `precision`. `update_op` weights each prediction by the corresponding value in
   1827   `weights`.
   1828 
   1829   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1830 
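          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 0, 1, 1])       # illustrative data
              predictions = tf.constant([1, 1, 1, 0])
              prec, update_op = tf.metrics.precision(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(prec))  # ~0.6667: 2 true positives, 1 false positive
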
   1831   Args:
   1832     labels: The ground truth values, a `Tensor` whose dimensions must match
   1833       `predictions`. Will be cast to `bool`.
   1834     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   1835       be cast to `bool`.
   1836     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1837       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1838       be either `1`, or the same as the corresponding `labels` dimension).
   1839     metrics_collections: An optional list of collections that `precision` should
   1840       be added to.
   1841     updates_collections: An optional list of collections that `update_op` should
   1842       be added to.
   1843     name: An optional variable_scope name.
   1844 
   1845   Returns:
   1846     precision: Scalar float `Tensor` with the value of `true_positives`
   1847       divided by the sum of `true_positives` and `false_positives`.
   1848     update_op: `Operation` that increments `true_positives` and
   1849       `false_positives` variables appropriately and whose value matches
   1850       `precision`.
   1851 
   1852   Raises:
   1853     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1854       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1855       either `metrics_collections` or `updates_collections` are not a list or
   1856       tuple.
   1857     RuntimeError: If eager execution is enabled.
   1858   """
   1859   if context.in_eager_mode():
   1860     raise RuntimeError('tf.metrics.precision is not '
   1861                        'supported when eager execution is enabled.')
   1862 
   1863   with variable_scope.variable_scope(name, 'precision',
   1864                                      (predictions, labels, weights)):
   1865 
   1866     predictions, labels, weights = _remove_squeezable_dimensions(
   1867         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   1868         labels=math_ops.cast(labels, dtype=dtypes.bool),
   1869         weights=weights)
   1870 
   1871     true_p, true_positives_update_op = true_positives(
   1872         labels,
   1873         predictions,
   1874         weights,
   1875         metrics_collections=None,
   1876         updates_collections=None,
   1877         name=None)
   1878     false_p, false_positives_update_op = false_positives(
   1879         labels,
   1880         predictions,
   1881         weights,
   1882         metrics_collections=None,
   1883         updates_collections=None,
   1884         name=None)
   1885 
   1886     def compute_precision(tp, fp, name):
   1887       return array_ops.where(
   1888           math_ops.greater(tp + fp, 0), math_ops.div(tp, tp + fp), 0, name)
   1889 
   1890     p = compute_precision(true_p, false_p, 'value')
   1891     update_op = compute_precision(true_positives_update_op,
   1892                                   false_positives_update_op, 'update_op')
   1893 
   1894     if metrics_collections:
   1895       ops.add_to_collections(metrics_collections, p)
   1896 
   1897     if updates_collections:
   1898       ops.add_to_collections(updates_collections, update_op)
   1899 
   1900     return p, update_op
   1901 
   1902 
   1903 @tf_export('metrics.precision_at_thresholds')
   1904 def precision_at_thresholds(labels,
   1905                             predictions,
   1906                             thresholds,
   1907                             weights=None,
   1908                             metrics_collections=None,
   1909                             updates_collections=None,
   1910                             name=None):
   1911   """Computes precision values for different `thresholds` on `predictions`.
   1912 
   1913   The `precision_at_thresholds` function creates four local variables,
   1914   `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
   1915   for various values of thresholds. `precision[i]` is defined as the total
   1916   weight of values in `predictions` above `thresholds[i]` whose corresponding
   1917   entry in `labels` is `True`, divided by the total weight of values in
   1918   `predictions` above `thresholds[i]` (`true_positives[i] / (true_positives[i] +
   1919   false_positives[i])`).
   1920 
   1921   For estimation of the metric over a stream of data, the function creates an
   1922   `update_op` operation that updates these variables and returns the
   1923   `precision`.
   1924 
   1925   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   1926 
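          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 0, 1])             # illustrative data
              predictions = tf.constant([0.9, 0.3, 0.6])
              prec, update_op = tf.metrics.precision_at_thresholds(
                  labels, predictions, thresholds=[0.2, 0.5])
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(prec))  # ~[0.6667, 1.], up to a small epsilon
                                       # used to avoid division by zero
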
   1927   Args:
   1928     labels: The ground truth values, a `Tensor` whose dimensions must match
   1929       `predictions`. Will be cast to `bool`.
   1930     predictions: A floating point `Tensor` of arbitrary shape and whose values
   1931       are in the range `[0, 1]`.
   1932     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   1933     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   1934       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   1935       be either `1`, or the same as the corresponding `labels` dimension).
   1936     metrics_collections: An optional list of collections that `precision`
   1937       should be added to.
   1938     updates_collections: An optional list of collections that `update_op` should
   1939       be added to.
   1940     name: An optional variable_scope name.
   1941 
   1942   Returns:
   1943     precision: A float `Tensor` of shape `[len(thresholds)]`.
   1944     update_op: An operation that increments the `true_positives`,
   1945       `true_negatives`, `false_positives` and `false_negatives` variables that
   1946       are used in the computation of `precision`.
   1947 
   1948   Raises:
   1949     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   1950       `weights` is not `None` and its shape doesn't match `predictions`, or if
   1951       either `metrics_collections` or `updates_collections` are not a list or
   1952       tuple.
   1953     RuntimeError: If eager execution is enabled.
   1954   """
   1955   if context.in_eager_mode():
   1956     raise RuntimeError('tf.metrics.precision_at_thresholds is not '
   1957                        'supported when eager execution is enabled.')
   1958 
   1959   with variable_scope.variable_scope(name, 'precision_at_thresholds',
   1960                                      (predictions, labels, weights)):
   1961     values, update_ops = _confusion_matrix_at_thresholds(
   1962         labels, predictions, thresholds, weights, includes=('tp', 'fp'))
   1963 
   1964     # Avoid division by zero.
   1965     epsilon = 1e-7
   1966 
   1967     def compute_precision(tp, fp, name):
   1968       return math_ops.div(tp, epsilon + tp + fp, name='precision_' + name)
   1969 
   1970     prec = compute_precision(values['tp'], values['fp'], 'value')
   1971     update_op = compute_precision(update_ops['tp'], update_ops['fp'],
   1972                                   'update_op')
   1973 
   1974     if metrics_collections:
   1975       ops.add_to_collections(metrics_collections, prec)
   1976 
   1977     if updates_collections:
   1978       ops.add_to_collections(updates_collections, update_op)
   1979 
   1980     return prec, update_op
   1981 
   1982 
   1983 @tf_export('metrics.recall')
   1984 def recall(labels,
   1985            predictions,
   1986            weights=None,
   1987            metrics_collections=None,
   1988            updates_collections=None,
   1989            name=None):
   1990   """Computes the recall of the predictions with respect to the labels.
   1991 
   1992   The `recall` function creates two local variables, `true_positives`
   1993   and `false_negatives`, that are used to compute the recall. This value is
   1994   ultimately returned as `recall`, an idempotent operation that simply divides
   1995   `true_positives` by the sum of `true_positives` and `false_negatives`.
   1996 
   1997   For estimation of the metric over a stream of data, the function creates an
   1998   `update_op` that updates these variables and returns the `recall`. `update_op`
   1999   weights each prediction by the corresponding value in `weights`.
   2000 
   2001   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2002 
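          For example, a minimal sketch with made-up data (`tf` is assumed to
          be the imported tensorflow module):

              labels = tf.constant([1, 0, 1, 1])       # illustrative data
              predictions = tf.constant([1, 1, 1, 0])
              rec, update_op = tf.metrics.recall(labels, predictions)
              with tf.Session() as sess:
                sess.run(tf.local_variables_initializer())
                sess.run(update_op)
                print(sess.run(rec))  # ~0.6667: 2 true positives, 1 false negative
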
   2003   Args:
   2004     labels: The ground truth values, a `Tensor` whose dimensions must match
   2005       `predictions`. Will be cast to `bool`.
   2006     predictions: The predicted values, a `Tensor` of arbitrary dimensions. Will
   2007       be cast to `bool`.
   2008     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   2009       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   2010       be either `1`, or the same as the corresponding `labels` dimension).
   2011     metrics_collections: An optional list of collections that `recall` should
   2012       be added to.
   2013     updates_collections: An optional list of collections that `update_op` should
   2014       be added to.
   2015     name: An optional variable_scope name.
   2016 
   2017   Returns:
   2018     recall: Scalar float `Tensor` with the value of `true_positives` divided
   2019       by the sum of `true_positives` and `false_negatives`.
   2020     update_op: `Operation` that increments `true_positives` and
   2021       `false_negatives` variables appropriately and whose value matches
   2022       `recall`.
   2023 
   2024   Raises:
   2025     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   2026       `weights` is not `None` and its shape doesn't match `predictions`, or if
   2027       either `metrics_collections` or `updates_collections` are not a list or
   2028       tuple.
   2029     RuntimeError: If eager execution is enabled.
   2030   """
   2031   if context.in_eager_mode():
   2032     raise RuntimeError('tf.metrics.recall is not supported when '
   2033                        'eager execution is enabled.')
   2034 
   2035   with variable_scope.variable_scope(name, 'recall',
   2036                                      (predictions, labels, weights)):
   2037     predictions, labels, weights = _remove_squeezable_dimensions(
   2038         predictions=math_ops.cast(predictions, dtype=dtypes.bool),
   2039         labels=math_ops.cast(labels, dtype=dtypes.bool),
   2040         weights=weights)
   2041 
   2042     true_p, true_positives_update_op = true_positives(
   2043         labels,
   2044         predictions,
   2045         weights,
   2046         metrics_collections=None,
   2047         updates_collections=None,
   2048         name=None)
   2049     false_n, false_negatives_update_op = false_negatives(
   2050         labels,
   2051         predictions,
   2052         weights,
   2053         metrics_collections=None,
   2054         updates_collections=None,
   2055         name=None)
   2056 
   2057     def compute_recall(true_p, false_n, name):
   2058       return array_ops.where(
   2059           math_ops.greater(true_p + false_n, 0),
   2060           math_ops.div(true_p, true_p + false_n), 0, name)
   2061 
   2062     rec = compute_recall(true_p, false_n, 'value')
   2063     update_op = compute_recall(true_positives_update_op,
   2064                                false_negatives_update_op, 'update_op')
   2065 
   2066     if metrics_collections:
   2067       ops.add_to_collections(metrics_collections, rec)
   2068 
   2069     if updates_collections:
   2070       ops.add_to_collections(updates_collections, update_op)
   2071 
   2072     return rec, update_op
   2073 
   2074 
   2075 def _at_k_name(name, k=None, class_id=None):
   2076   if k is not None:
   2077     name = '%s_at_%d' % (name, k)
   2078   else:
   2079     name = '%s_at_k' % (name)
   2080   if class_id is not None:
   2081     name = '%s_class%d' % (name, class_id)
   2082   return name
   2083 
   2084 
   2085 def _select_class_id(ids, selected_id):
   2086   """Filter all but `selected_id` out of `ids`.
   2087 
   2088   Args:
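          For example (illustrative): if `ids` is the dense `Tensor`
          [[1, 2], [3, 1]] and `selected_id` is 1, the result is a
          `SparseTensor` with dense shape [2, 2] whose only entries are the
          two 1s.
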
   2089     ids: `int64` `Tensor` or `SparseTensor` of IDs.
   2090     selected_id: Int id to select.
   2091 
   2092   Returns:
   2093     `SparseTensor` of same dimensions as `ids`. This contains only the entries
   2094     equal to `selected_id`.
   2095   """
   2096   ids = sparse_tensor.convert_to_tensor_or_sparse_tensor(ids)
   2097   if isinstance(ids, sparse_tensor.SparseTensor):
   2098     return sparse_ops.sparse_retain(ids, math_ops.equal(ids.values,
   2099                                                         selected_id))
   2100 
   2101   # TODO(ptucker): Make this more efficient, maybe add a sparse version of
   2102   # tf.equal and tf.reduce_any?
   2103 
   2104   # Shape of filled IDs is the same as `ids` with the last dim collapsed to 1.
   2105   ids_shape = array_ops.shape(ids, out_type=dtypes.int64)
   2106   ids_last_dim = array_ops.size(ids_shape) - 1
   2107   filled_selected_id_shape = math_ops.reduced_shape(ids_shape,
   2108                                                     array_ops.reshape(
   2109                                                         ids_last_dim, [1]))
   2110 
   2111   # Intersect `ids` with the selected ID.
   2112   filled_selected_id = array_ops.fill(filled_selected_id_shape,
   2113                                       math_ops.to_int64(selected_id))
   2114   result = sets.set_intersection(filled_selected_id, ids)
   2115   return sparse_tensor.SparseTensor(
   2116       indices=result.indices, values=result.values, dense_shape=ids_shape)
   2117 
   2118 
   2119 def _maybe_select_class_id(labels, predictions_idx, selected_id=None):
   2120   """If class ID is specified, filter all other classes.
   2121 
   2122   Args:
   2123     labels: `int64` `Tensor` or `SparseTensor` with shape
   2124       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2125       target classes for the associated prediction. Commonly, N=1 and `labels`
   2126       has shape [batch_size, num_labels]. [D1, ... DN] must match
   2127       `predictions_idx`.
   2128     predictions_idx: `int64` `Tensor` of class IDs, with shape [D1, ... DN, k]
   2129       where N >= 1. Commonly, N=1 and `predictions_idx` has shape
   2130       [batch size, k].
   2131     selected_id: Int id to select.
   2132 
   2133   Returns:
   2134     Tuple of `labels` and `predictions_idx`, possibly with classes removed.
   2135   """
   2136   if selected_id is None:
   2137     return labels, predictions_idx
   2138   return (_select_class_id(labels, selected_id),
   2139           _select_class_id(predictions_idx, selected_id))
   2140 
   2141 
   2142 def _sparse_true_positive_at_k(labels,
   2143                                predictions_idx,
   2144                                class_id=None,
   2145                                weights=None,
   2146                                name=None):
   2147   """Calculates true positives for recall@k and precision@k.
   2148 
   2149   If `class_id` is specified, calculate binary true positives for `class_id`
   2150       only.
   2151   If `class_id` is not specified, calculate metrics for `k` predicted vs
   2152       `n` label classes, where `n` is the 2nd dimension of `labels`.
   2153 
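          For example (illustrative): with `labels` [[1, 2]] and
          `predictions_idx` [[2, 3]], the per-example intersection is {2}, so
          the result is [1.0].
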
   2154   Args:
   2155     labels: `int64` `Tensor` or `SparseTensor` with shape
   2156       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2157       target classes for the associated prediction. Commonly, N=1 and `labels`
   2158       has shape [batch_size, num_labels]. [D1, ... DN] must match
   2159       `predictions_idx`.
   2160     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   2161       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   2162       match `labels`.
   2163     class_id: Class for which we want binary metrics.
   2164     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2165       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2166       dimensions must be either `1`, or the same as the corresponding `labels`
   2167       dimension).
   2168     name: Name of operation.
   2169 
   2170   Returns:
   2171     A [D1, ... DN] `Tensor` of true positive counts.
   2172   """
   2173   with ops.name_scope(name, 'true_positives',
   2174                       (predictions_idx, labels, weights)):
   2175     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
   2176                                                      class_id)
   2177     tp = sets.set_size(sets.set_intersection(predictions_idx, labels))
   2178     tp = math_ops.to_double(tp)
   2179     if weights is not None:
   2180       with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
   2181           weights, tp),)):
   2182         weights = math_ops.to_double(weights)
   2183         tp = math_ops.multiply(tp, weights)
   2184     return tp
   2185 
   2186 
   2187 def _streaming_sparse_true_positive_at_k(labels,
   2188                                          predictions_idx,
   2189                                          k=None,
   2190                                          class_id=None,
   2191                                          weights=None,
   2192                                          name=None):
   2193   """Calculates weighted per step true positives for recall@k and precision@k.
   2194 
   2195   If `class_id` is specified, calculate binary true positives for `class_id`
   2196       only.
   2197   If `class_id` is not specified, calculate metrics for `k` predicted vs
   2198       `n` label classes, where `n` is the 2nd dimension of `labels`.
   2199 
   2200   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2201 
   2202   Args:
   2203     labels: `int64` `Tensor` or `SparseTensor` with shape
   2204       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2205       target classes for the associated prediction. Commonly, N=1 and `labels`
   2206       has shape [batch_size, num_labels]. [D1, ... DN] must match
   2207       `predictions_idx`.
   2208     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   2209       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   2210       match `labels`.
   2211     k: Integer, k for @k metric. This is only used for default op name.
   2212     class_id: Class for which we want binary metrics.
   2213     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2214       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2215       dimensions must be either `1`, or the same as the corresponding `labels`
   2216       dimension).
   2217     name: Name of new variable, and namespace for other dependent ops.
   2218 
   2219   Returns:
   2220     A tuple of `Variable` and update `Operation`.
   2221 
   2222   Raises:
   2223     ValueError: If `weights` is not `None` and has an incompatible shape.
   2224   """
   2225   with ops.name_scope(name, _at_k_name('true_positive', k, class_id=class_id),
   2226                       (predictions_idx, labels, weights)) as scope:
   2227     tp = _sparse_true_positive_at_k(
   2228         predictions_idx=predictions_idx,
   2229         labels=labels,
   2230         class_id=class_id,
   2231         weights=weights)
   2232     batch_total_tp = math_ops.to_double(math_ops.reduce_sum(tp))
   2233 
   2234     var = metric_variable([], dtypes.float64, name=scope)
   2235     return var, state_ops.assign_add(var, batch_total_tp, name='update')
   2236 
   2237 
   2238 def _sparse_false_negative_at_k(labels,
   2239                                 predictions_idx,
   2240                                 class_id=None,
   2241                                 weights=None):
   2242   """Calculates false negatives for recall@k.
   2243 
   2244   If `class_id` is specified, calculate binary false negatives for `class_id`
   2245       only.
   2246   If `class_id` is not specified, calculate metrics for `k` predicted vs
   2247       `n` label classes, where `n` is the 2nd dimension of `labels`.
   2248 
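          For example (illustrative): with `labels` [[1, 2]] and
          `predictions_idx` [[2, 3]], label class 1 is never predicted, so the
          result is [1.0].
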
   2249   Args:
   2250     labels: `int64` `Tensor` or `SparseTensor` with shape
   2251       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2252       target classes for the associated prediction. Commonly, N=1 and `labels`
   2253       has shape [batch_size, num_labels]. [D1, ... DN] must match
   2254       `predictions_idx`.
   2255     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   2256       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   2257       match `labels`.
   2258     class_id: Class for which we want binary metrics.
   2259     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2260       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2261       dimensions must be either `1`, or the same as the corresponding `labels`
   2262       dimension).
   2263 
   2264   Returns:
   2265     A [D1, ... DN] `Tensor` of false negative counts.
   2266   """
   2267   with ops.name_scope(None, 'false_negatives',
   2268                       (predictions_idx, labels, weights)):
   2269     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
   2270                                                      class_id)
   2271     fn = sets.set_size(
   2272         sets.set_difference(predictions_idx, labels, aminusb=False))
   2273     fn = math_ops.to_double(fn)
   2274     if weights is not None:
   2275       with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
   2276           weights, fn),)):
   2277         weights = math_ops.to_double(weights)
   2278         fn = math_ops.multiply(fn, weights)
   2279     return fn
   2280 
   2281 
   2282 def _streaming_sparse_false_negative_at_k(labels,
   2283                                           predictions_idx,
   2284                                           k,
   2285                                           class_id=None,
   2286                                           weights=None,
   2287                                           name=None):
   2288   """Calculates weighted per step false negatives for recall@k.
   2289 
   2290   If `class_id` is specified, calculate binary false negatives for `class_id`
   2291       only.
   2292   If `class_id` is not specified, calculate metrics for `k` predicted vs
   2293       `n` label classes, where `n` is the 2nd dimension of `labels`.
   2294 
   2295   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2296 
   2297   Args:
   2298     labels: `int64` `Tensor` or `SparseTensor` with shape
   2299       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2300       target classes for the associated prediction. Commonly, N=1 and `labels`
   2301       has shape [batch_size, num_labels]. [D1, ... DN] must match
   2302       `predictions_idx`.
   2303     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   2304       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   2305       match `labels`.
   2306     k: Integer, k for @k metric. This is only used for the default op name.
   2307     class_id: Class for which we want binary metrics.
   2308     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2309       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2310       dimensions must be either `1`, or the same as the corresponding `labels`
   2311       dimension).
   2312     name: Name of new variable, and namespace for other dependent ops.
   2313 
   2314   Returns:
   2315     A tuple of `Variable` and update `Operation`.
   2316 
   2317   Raises:
   2318     ValueError: If `weights` is not `None` and has an incompatible shape.
   2319   """
   2320   with ops.name_scope(name, _at_k_name('false_negative', k, class_id=class_id),
   2321                       (predictions_idx, labels, weights)) as scope:
   2322     fn = _sparse_false_negative_at_k(
   2323         predictions_idx=predictions_idx,
   2324         labels=labels,
   2325         class_id=class_id,
   2326         weights=weights)
   2327     batch_total_fn = math_ops.to_double(math_ops.reduce_sum(fn))
   2328 
   2329     var = metric_variable([], dtypes.float64, name=scope)
   2330     return var, state_ops.assign_add(var, batch_total_fn, name='update')
   2331 
   2332 
   2333 @tf_export('metrics.recall_at_k')
   2334 def recall_at_k(labels,
   2335                 predictions,
   2336                 k,
   2337                 class_id=None,
   2338                 weights=None,
   2339                 metrics_collections=None,
   2340                 updates_collections=None,
   2341                 name=None):
   2342   """Computes recall@k of the predictions with respect to sparse labels.
   2343 
   2344   If `class_id` is specified, we calculate recall by considering only the
   2345       entries in the batch for which `class_id` is in the label, and computing
   2346       the fraction of them for which `class_id` is in the top-k `predictions`.
   2347   If `class_id` is not specified, we'll calculate recall as how often on
   2348       average a class among the labels of a batch entry is in the top-k
   2349       `predictions`.
   2350 
   2351   `recall_at_k` creates two local variables,
   2352   `true_positive_at_<k>` and `false_negative_at_<k>`, that are used to compute
   2353   the recall@k frequency. This frequency is ultimately returned as
   2354   `recall_at_<k>`: an idempotent operation that simply divides
   2355   `true_positive_at_<k>` by the sum of `true_positive_at_<k>` and
   2356   `false_negative_at_<k>`.
   2357 
   2358   For estimation of the metric over a stream of data, the function creates an
   2359   `update_op` operation that updates these variables and returns the
   2360   `recall_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
   2361   indicating the top `k` `predictions`. Set operations applied to `top_k` and
   2362   `labels` calculate the true positives and false negatives weighted by
   2363   `weights`. Then `update_op` increments `true_positive_at_<k>` and
   2364   `false_negative_at_<k>` using these values.
   2365 
   2366   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2367 
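          For example, a minimal graph-mode sketch (the values below are
          hypothetical, not part of this API):

            import tensorflow as tf

            labels = tf.constant([[1], [2]], dtype=tf.int64)
            predictions = tf.constant([[0.1, 0.4, 0.3, 0.2],
                                       [0.5, 0.1, 0.1, 0.3]])
            # Top-2 classes are [1, 2] for row 1 and [0, 3] for row 2, so only
            # the first row's label is recalled.
            recall, update_op = tf.metrics.recall_at_k(labels, predictions, k=2)
            with tf.Session() as sess:
              sess.run(tf.local_variables_initializer())
              sess.run(update_op)
              print(sess.run(recall))  # 0.5
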
   2368   Args:
   2369     labels: `int64` `Tensor` or `SparseTensor` with shape
   2370       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   2371       num_labels=1. N >= 1 and num_labels is the number of target classes for
   2372       the associated prediction. Commonly, N=1 and `labels` has shape
   2373       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
   2374       should be in range [0, num_classes), where num_classes is the last
   2375       dimension of `predictions`. Values outside this range always count
   2376       towards `false_negative_at_<k>`.
   2377     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
   2378       N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
   2379       The final dimension contains the logit values for each class. [D1, ... DN]
   2380       must match `labels`.
   2381     k: Integer, k for @k metric.
   2382     class_id: Integer class ID for which we want binary metrics. This should be
   2383       in range [0, num_classes), where num_classes is the last dimension of
   2384       `predictions`. If class_id is outside this range, the method returns NAN.
   2385     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2386       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2387       dimensions must be either `1`, or the same as the corresponding `labels`
   2388       dimension).
   2389     metrics_collections: An optional list of collections that values should
   2390       be added to.
   2391     updates_collections: An optional list of collections that updates should
   2392       be added to.
   2393     name: Name of new update operation, and namespace for other dependent ops.
   2394 
   2395   Returns:
   2396     recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
   2397       by the sum of `true_positives` and `false_negatives`.
   2398     update_op: `Operation` that increments `true_positives` and
   2399       `false_negatives` variables appropriately, and whose value matches
   2400       `recall`.
   2401 
   2402   Raises:
   2403     ValueError: If `weights` is not `None` and its shape doesn't match
   2404       `predictions`, or if either `metrics_collections` or `updates_collections`
   2405       are not a list or tuple.
   2406     RuntimeError: If eager execution is enabled.
   2407   """
   2408   if context.in_eager_mode():
   2409     raise RuntimeError('tf.metrics.recall_at_k is not '
   2410                        'supported when eager execution is enabled.')
   2411 
   2412   with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
   2413                       (predictions, labels, weights)) as scope:
   2414     _, top_k_idx = nn.top_k(predictions, k)
   2415     return recall_at_top_k(
   2416         labels=labels,
   2417         predictions_idx=top_k_idx,
   2418         k=k,
   2419         class_id=class_id,
   2420         weights=weights,
   2421         metrics_collections=metrics_collections,
   2422         updates_collections=updates_collections,
   2423         name=scope)
   2424 
   2425 
   2426 @tf_export('metrics.recall_at_top_k')
   2427 def recall_at_top_k(labels,
   2428                     predictions_idx,
   2429                     k=None,
   2430                     class_id=None,
   2431                     weights=None,
   2432                     metrics_collections=None,
   2433                     updates_collections=None,
   2434                     name=None):
   2435   """Computes recall@k of top-k predictions with respect to sparse labels.
   2436 
   2437   Differs from `recall_at_k` in that predictions must be in the form of top `k`
   2438   class indices, whereas `recall_at_k` expects logits. Refer to `recall_at_k`
   2439   for more details.
   2440 
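          For example, a minimal sketch with hypothetical values, assuming the
          top-2 class indices have already been computed:

            import tensorflow as tf

            labels = tf.constant([[1], [2]], dtype=tf.int64)
            predictions_idx = tf.constant([[1, 2], [0, 3]], dtype=tf.int64)
            recall, update_op = tf.metrics.recall_at_top_k(
                labels, predictions_idx, k=2)
            # After `update_op` runs, `recall` evaluates to 0.5: only the first
            # row's label appears among its top-2 indices.
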
   2441   Args:
   2442     labels: `int64` `Tensor` or `SparseTensor` with shape
   2443       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   2444       num_labels=1. N >= 1 and num_labels is the number of target classes for
   2445       the associated prediction. Commonly, N=1 and `labels` has shape
   2446       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
   2447       should be in range [0, num_classes), where num_classes is the last
   2448       dimension of `predictions`. Values outside this range always count
   2449       towards `false_negative_at_<k>`.
   2450     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
   2451       Commonly, N=1 and predictions has shape [batch size, k]. The final
   2452       dimension contains the top `k` predicted class indices. [D1, ... DN] must
   2453       match `labels`.
   2454     k: Integer, k for @k metric. Only used for the default op name.
   2455     class_id: Integer class ID for which we want binary metrics. This should be
   2456       in range [0, num_classes), where num_classes is the last dimension of
   2457       `predictions`. If class_id is outside this range, the method returns NAN.
   2458     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2459       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2460       dimensions must be either `1`, or the same as the corresponding `labels`
   2461       dimension).
   2462     metrics_collections: An optional list of collections that values should
   2463       be added to.
   2464     updates_collections: An optional list of collections that updates should
   2465       be added to.
   2466     name: Name of new update operation, and namespace for other dependent ops.
   2467 
   2468   Returns:
   2469     recall: Scalar `float64` `Tensor` with the value of `true_positives` divided
   2470       by the sum of `true_positives` and `false_negatives`.
   2471     update_op: `Operation` that increments `true_positives` and
   2472       `false_negatives` variables appropriately, and whose value matches
   2473       `recall`.
   2474 
   2475   Raises:
   2476     ValueError: If `weights` is not `None` and its shape doesn't match
   2477       `predictions`, or if either `metrics_collections` or `updates_collections`
   2478       are not a list or tuple.
   2479   """
   2480   with ops.name_scope(name, _at_k_name('recall', k, class_id=class_id),
   2481                       (predictions_idx, labels, weights)) as scope:
   2482     labels = _maybe_expand_labels(labels, predictions_idx)
   2483     top_k_idx = math_ops.to_int64(predictions_idx)
   2484     tp, tp_update = _streaming_sparse_true_positive_at_k(
   2485         predictions_idx=top_k_idx,
   2486         labels=labels,
   2487         k=k,
   2488         class_id=class_id,
   2489         weights=weights)
   2490     fn, fn_update = _streaming_sparse_false_negative_at_k(
   2491         predictions_idx=top_k_idx,
   2492         labels=labels,
   2493         k=k,
   2494         class_id=class_id,
   2495         weights=weights)
   2496 
   2497     metric = math_ops.div(tp, math_ops.add(tp, fn), name=scope)
   2498     update = math_ops.div(
   2499         tp_update, math_ops.add(tp_update, fn_update), name='update')
   2500     if metrics_collections:
   2501       ops.add_to_collections(metrics_collections, metric)
   2502     if updates_collections:
   2503       ops.add_to_collections(updates_collections, update)
   2504     return metric, update
   2505 
   2506 
   2507 @tf_export('metrics.recall_at_thresholds')
   2508 def recall_at_thresholds(labels,
   2509                          predictions,
   2510                          thresholds,
   2511                          weights=None,
   2512                          metrics_collections=None,
   2513                          updates_collections=None,
   2514                          name=None):
   2515   """Computes various recall values for different `thresholds` on `predictions`.
   2516 
   2517   The `recall_at_thresholds` function creates four local variables,
   2518   `true_positives`, `true_negatives`, `false_positives` and `false_negatives`
   2519   for various values of thresholds. `recall[i]` is defined as the total weight
   2520   of values in `predictions` above `thresholds[i]` whose corresponding entry in
   2521   `labels` is `True`, divided by the total weight of `True` values in `labels`
   2522   (`true_positives[i] / (true_positives[i] + false_negatives[i])`).
   2523 
   2524   For estimation of the metric over a stream of data, the function creates an
   2525   `update_op` operation that updates these variables and returns the `recall`.
   2526 
   2527   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2528 
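          For example, a minimal sketch with hypothetical values:

            import tensorflow as tf

            labels = tf.constant([True, True, False, False])
            predictions = tf.constant([0.9, 0.3, 0.8, 0.1])
            recall, update_op = tf.metrics.recall_at_thresholds(
                labels, predictions, thresholds=[0.5])
            # At threshold 0.5, one of the two `True` labels has a prediction
            # above the threshold, so `recall` evaluates to ~0.5 after
            # `update_op` runs.
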
   2529   Args:
   2530     labels: The ground truth values, a `Tensor` whose dimensions must match
   2531       `predictions`. Will be cast to `bool`.
   2532     predictions: A floating point `Tensor` of arbitrary shape and whose values
   2533       are in the range `[0, 1]`.
   2534     thresholds: A python list or tuple of float thresholds in `[0, 1]`.
   2535     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   2536       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   2537       be either `1`, or the same as the corresponding `labels` dimension).
   2538     metrics_collections: An optional list of collections that `recall` should be
   2539       added to.
   2540     updates_collections: An optional list of collections that `update_op` should
   2541       be added to.
   2542     name: An optional variable_scope name.
   2543 
   2544   Returns:
   2545     recall: A float `Tensor` of shape `[len(thresholds)]`.
   2546     update_op: An operation that increments the `true_positives`,
   2547       `true_negatives`, `false_positives` and `false_negatives` variables that
   2548       are used in the computation of `recall`.
   2549 
   2550   Raises:
   2551     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   2552       `weights` is not `None` and its shape doesn't match `predictions`, or if
   2553       either `metrics_collections` or `updates_collections` are not a list or
   2554       tuple.
   2555     RuntimeError: If eager execution is enabled.
   2556   """
   2557   if context.in_eager_mode():
   2558     raise RuntimeError('tf.metrics.recall_at_thresholds is not '
   2559                        'supported when eager execution is enabled.')
   2560 
   2561   with variable_scope.variable_scope(name, 'recall_at_thresholds',
   2562                                      (predictions, labels, weights)):
   2563     values, update_ops = _confusion_matrix_at_thresholds(
   2564         labels, predictions, thresholds, weights, includes=('tp', 'fn'))
   2565 
   2566     # Avoid division by zero.
   2567     epsilon = 1e-7
   2568 
   2569     def compute_recall(tp, fn, name):
   2570       return math_ops.div(tp, epsilon + tp + fn, name='recall_' + name)
   2571 
   2572     rec = compute_recall(values['tp'], values['fn'], 'value')
   2573     update_op = compute_recall(update_ops['tp'], update_ops['fn'], 'update_op')
   2574 
   2575     if metrics_collections:
   2576       ops.add_to_collections(metrics_collections, rec)
   2577 
   2578     if updates_collections:
   2579       ops.add_to_collections(updates_collections, update_op)
   2580 
   2581     return rec, update_op
   2582 
   2583 
   2584 @tf_export('metrics.root_mean_squared_error')
   2585 def root_mean_squared_error(labels,
   2586                             predictions,
   2587                             weights=None,
   2588                             metrics_collections=None,
   2589                             updates_collections=None,
   2590                             name=None):
   2591   """Computes the root mean squared error between the labels and predictions.
   2592 
   2593   The `root_mean_squared_error` function creates two local variables,
   2594   `total` and `count` that are used to compute the root mean squared error.
   2595   This average is weighted by `weights`, and it is ultimately returned as
   2596   `root_mean_squared_error`: an idempotent operation that takes the square root
   2597   of the division of `total` by `count`.
   2598 
   2599   For estimation of the metric over a stream of data, the function creates an
   2600   `update_op` operation that updates these variables and returns the
   2601   `root_mean_squared_error`. Internally, a `squared_error` operation computes
   2602   the element-wise square of the difference between `predictions` and `labels`.
   2603   Then `update_op` increments `total` with the reduced sum of the product of
   2604   `weights` and `squared_error`, and it increments `count` with the reduced sum
   2605   of `weights`.
   2606 
   2607   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2608 
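          For example, a minimal sketch with hypothetical values:

            import tensorflow as tf

            labels = tf.constant([1.0, 2.0, 3.0])
            predictions = tf.constant([1.0, 2.0, 5.0])
            rmse, update_op = tf.metrics.root_mean_squared_error(
                labels, predictions)
            # Squared errors are [0, 0, 4], so after `update_op` runs, `rmse`
            # evaluates to sqrt(4 / 3), roughly 1.155.
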
   2609   Args:
   2610     labels: A `Tensor` of the same shape as `predictions`.
   2611     predictions: A `Tensor` of arbitrary shape.
   2612     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   2613       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   2614       be either `1`, or the same as the corresponding `labels` dimension).
   2615     metrics_collections: An optional list of collections that
   2616       `root_mean_squared_error` should be added to.
   2617     updates_collections: An optional list of collections that `update_op` should
   2618       be added to.
   2619     name: An optional variable_scope name.
   2620 
   2621   Returns:
   2622     root_mean_squared_error: A `Tensor` representing the current root mean
   2623       squared error: the square root of `total` divided by `count`.
   2624     update_op: An operation that increments the `total` and `count` variables
   2625       appropriately and whose value matches `root_mean_squared_error`.
   2626 
   2627   Raises:
   2628     ValueError: If `predictions` and `labels` have mismatched shapes, or if
   2629       `weights` is not `None` and its shape doesn't match `predictions`, or if
   2630       either `metrics_collections` or `updates_collections` are not a list or
   2631       tuple.
   2632     RuntimeError: If eager execution is enabled.
   2633   """
   2634   if context.in_eager_mode():
   2635     raise RuntimeError('tf.metrics.root_mean_squared_error is not '
   2636                        'supported when eager execution is enabled.')
   2637 
   2638   predictions, labels, weights = _remove_squeezable_dimensions(
   2639       predictions=predictions, labels=labels, weights=weights)
   2640   mse, update_mse_op = mean_squared_error(labels, predictions, weights, None,
   2641                                           None, name or
   2642                                           'root_mean_squared_error')
   2643 
   2644   rmse = math_ops.sqrt(mse)
   2645   update_rmse_op = math_ops.sqrt(update_mse_op)
   2646 
   2647   if metrics_collections:
   2648     ops.add_to_collections(metrics_collections, rmse)
   2649 
   2650   if updates_collections:
   2651     ops.add_to_collections(updates_collections, update_rmse_op)
   2652 
   2653   return rmse, update_rmse_op
   2654 
   2655 
   2656 @tf_export('metrics.sensitivity_at_specificity')
   2657 def sensitivity_at_specificity(labels,
   2658                                predictions,
   2659                                specificity,
   2660                                weights=None,
   2661                                num_thresholds=200,
   2662                                metrics_collections=None,
   2663                                updates_collections=None,
   2664                                name=None):
   2665   """Computes the specificity at a given sensitivity.
   2666 
   2667   The `sensitivity_at_specificity` function creates four local
   2668   variables, `true_positives`, `true_negatives`, `false_positives` and
   2669   `false_negatives` that are used to compute the sensitivity at the given
   2670   specificity value. The threshold for the given specificity value is computed
   2671   and used to evaluate the corresponding sensitivity.
   2672 
   2673   For estimation of the metric over a stream of data, the function creates an
   2674   `update_op` operation that updates these variables and returns the
   2675   `sensitivity`. `update_op` increments the `true_positives`, `true_negatives`,
   2676   `false_positives` and `false_negatives` counts with the weight of each case
   2677   found in the `predictions` and `labels`.
   2678 
   2679   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2680 
   2681   For additional information about specificity and sensitivity, see the
   2682   following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
   2683 
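          For example, a minimal sketch with hypothetical values:

            import tensorflow as tf

            labels = tf.constant([True, True, False, False])
            predictions = tf.constant([0.9, 0.6, 0.4, 0.1])
            sensitivity, update_op = tf.metrics.sensitivity_at_specificity(
                labels, predictions, specificity=1.0)
            # Both negatives score below both positives, so some threshold
            # reaches specificity 1.0, and `sensitivity` evaluates to ~1.0
            # after `update_op` runs.
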
   2684   Args:
   2685     labels: The ground truth values, a `Tensor` whose dimensions must match
   2686       `predictions`. Will be cast to `bool`.
   2687     predictions: A floating point `Tensor` of arbitrary shape and whose values
   2688       are in the range `[0, 1]`.
   2689     specificity: A scalar value in range `[0, 1]`.
   2690     weights: Optional `Tensor` whose rank is either 0, or the same rank as
   2691       `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
   2692       be either `1`, or the same as the corresponding `labels` dimension).
   2693     num_thresholds: The number of thresholds to use for matching the given
   2694       specificity.
   2695     metrics_collections: An optional list of collections that `sensitivity`
   2696       should be added to.
   2697     updates_collections: An optional list of collections that `update_op` should
   2698       be added to.
   2699     name: An optional variable_scope name.
   2700 
   2701   Returns:
   2702     sensitivity: A scalar `Tensor` representing the sensitivity at the given
   2703       `specificity` value.
   2704     update_op: An operation that increments the `true_positives`,
   2705       `true_negatives`, `false_positives` and `false_negatives` variables
   2706       appropriately and whose value matches `sensitivity`.
   2707 
   2708   Raises:
   2709     ValueError: If `predictions` and `labels` have mismatched shapes, if
   2710       `weights` is not `None` and its shape doesn't match `predictions`, or if
   2711       `specificity` is not between 0 and 1, or if either `metrics_collections`
   2712       or `updates_collections` are not a list or tuple.
   2713     RuntimeError: If eager execution is enabled.
   2714   """
   2715   if context.in_eager_mode():
   2716     raise RuntimeError('tf.metrics.sensitivity_at_specificity is not '
   2717                        'supported when eager execution is enabled.')
   2718 
   2719   if specificity < 0 or specificity > 1:
   2720     raise ValueError('`specificity` must be in the range [0, 1].')
   2721 
   2722   with variable_scope.variable_scope(name, 'sensitivity_at_specificity',
   2723                                      (predictions, labels, weights)):
   2724     kepsilon = 1e-7  # to account for floating point imprecisions
   2725     thresholds = [
   2726         (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
   2727     ]
   2728     thresholds = [0.0 - kepsilon] + thresholds + [1.0 + kepsilon]
   2729 
   2730     values, update_ops = _confusion_matrix_at_thresholds(
   2731         labels, predictions, thresholds, weights)
   2732 
   2733     def compute_sensitivity_at_specificity(tp, tn, fp, fn, name):
   2734       specificities = math_ops.div(tn, tn + fp + kepsilon)
   2735       tf_index = math_ops.argmin(math_ops.abs(specificities - specificity), 0)
   2736       tf_index = math_ops.cast(tf_index, dtypes.int32)
   2737 
   2738       # Now, we have the implicit threshold, so compute the sensitivity:
   2739       return math_ops.div(tp[tf_index], tp[tf_index] + fn[tf_index] + kepsilon,
   2740                           name)
   2741 
   2742     sensitivity = compute_sensitivity_at_specificity(
   2743         values['tp'], values['tn'], values['fp'], values['fn'], 'value')
   2744     update_op = compute_sensitivity_at_specificity(
   2745         update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
   2746         'update_op')
   2747 
   2748     if metrics_collections:
   2749       ops.add_to_collections(metrics_collections, sensitivity)
   2750 
   2751     if updates_collections:
   2752       ops.add_to_collections(updates_collections, update_op)
   2753 
   2754     return sensitivity, update_op
   2755 
   2756 
   2757 def _expand_and_tile(tensor, multiple, dim=0, name=None):
   2758   """Slice `tensor` shape in 2, then tile along the sliced dimension.
   2759 
   2760   A new dimension is inserted in shape of `tensor` before `dim`, then values are
   2761   tiled `multiple` times along the new dimension.
   2762 
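          For example (dense case), a `tensor` of shape [2, 3] passed with
          `multiple=2` and `dim=-1` is first expanded to shape [2, 1, 3] and
          then tiled to shape [2, 2, 3].
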
   2763   Args:
   2764     tensor: Input `Tensor` or `SparseTensor`.
   2765     multiple: Integer, number of times to tile.
   2766     dim: Integer, dimension along which to tile.
   2767     name: Name of operation.
   2768 
   2769   Returns:
   2770     `Tensor` result of expanding and tiling `tensor`.
   2771 
   2772   Raises:
   2773     ValueError: if `multiple` is less than 1, or `dim` is not in
   2774       `[-rank(tensor), rank(tensor)]`.
   2775   """
   2776   if multiple < 1:
   2777     raise ValueError('Invalid multiple %s, must be > 0.' % multiple)
   2778   with ops.name_scope(name, 'expand_and_tile',
   2779                       (tensor, multiple, dim)) as scope:
   2780     # Sparse.
   2781     tensor = sparse_tensor.convert_to_tensor_or_sparse_tensor(tensor)
   2782     if isinstance(tensor, sparse_tensor.SparseTensor):
   2783       if dim < 0:
   2784         expand_dims = array_ops.reshape(
   2785             array_ops.size(tensor.dense_shape) + dim, [1])
   2786       else:
   2787         expand_dims = [dim]
   2788       expanded_shape = array_ops.concat(
   2789           (array_ops.slice(tensor.dense_shape, [0], expand_dims), [1],
   2790            array_ops.slice(tensor.dense_shape, expand_dims, [-1])),
   2791           0,
   2792           name='expanded_shape')
   2793       expanded = sparse_ops.sparse_reshape(
   2794           tensor, shape=expanded_shape, name='expand')
   2795       if multiple == 1:
   2796         return expanded
   2797       return sparse_ops.sparse_concat(
   2798           dim - 1 if dim < 0 else dim, [expanded] * multiple, name=scope)
   2799 
   2800     # Dense.
   2801     expanded = array_ops.expand_dims(
   2802         tensor, dim if (dim >= 0) else (dim - 1), name='expand')
   2803     if multiple == 1:
   2804       return expanded
   2805     ones = array_ops.ones_like(array_ops.shape(tensor))
   2806     tile_multiples = array_ops.concat(
   2807         (ones[:dim], (multiple,), ones[dim:]), 0, name='multiples')
   2808     return array_ops.tile(expanded, tile_multiples, name=scope)
   2809 
   2810 
   2811 def _num_relevant(labels, k):
   2812   """Computes number of relevant values for each row in labels.
   2813 
   2814   For labels with shape [D1, ... DN, num_labels], this is the minimum of
   2815   `num_labels` and `k`.
   2816 
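          For example, dense `labels` of shape [2, 5] yield [3, 3] for k=3 and
          [5, 5] for k=7.
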
   2817   Args:
   2818     labels: `int64` `Tensor` or `SparseTensor` with shape
   2819       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   2820       target classes for the associated prediction. Commonly, N=1 and `labels`
   2821       has shape [batch_size, num_labels].
   2822     k: Integer, k for @k metric.
   2823 
   2824   Returns:
   2825     Integer `Tensor` of shape [D1, ... DN], where each value is the number of
   2826     relevant values for that row.
   2827 
   2828   Raises:
   2829     ValueError: if inputs have invalid dtypes or values.
   2830   """
   2831   if k < 1:
   2832     raise ValueError('Invalid k=%s.' % k)
   2833   with ops.name_scope(None, 'num_relevant', (labels,)) as scope:
   2834     # For SparseTensor, calculate separate count for each row.
   2835     labels = sparse_tensor.convert_to_tensor_or_sparse_tensor(labels)
   2836     if isinstance(labels, sparse_tensor.SparseTensor):
   2837       return math_ops.minimum(sets.set_size(labels), k, name=scope)
   2838 
   2839     # For dense Tensor, calculate scalar count based on last dimension, and
   2840     # tile across labels shape.
   2841     labels_shape = array_ops.shape(labels)
   2842     labels_size = labels_shape[-1]
   2843     num_relevant_scalar = math_ops.minimum(labels_size, k)
   2844     return array_ops.fill(labels_shape[0:-1], num_relevant_scalar, name=scope)
   2845 
   2846 
   2847 def _sparse_average_precision_at_top_k(labels, predictions_idx):
   2848   """Computes average precision@k of predictions with respect to sparse labels.
   2849 
   2850   From en.wikipedia.org/wiki/Information_retrieval#Average_precision, formula
   2851   for each row is:
   2852 
   2853     AveP = sum_{i=1...k} P_{i} * rel_{i} / num_relevant_items
   2854 
   2855   A "row" is the elements in dimension [D1, ... DN] of `predictions_idx`,
   2856   `labels`, and the result `Tensors`. In the common case, this is [batch_size].
   2857   Each row of the results contains the average precision for that row.
   2858 
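          For example, a row with label classes {0, 2} and `predictions_idx`
          [2, 1, 0] (k=3) has relevance indicators [1, 0, 1] and precisions
          [1/1, 1/2, 2/3], so its average precision is (1 + 2/3) / 2 = 5/6.
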
   2859   Args:
   2860     labels: `int64` `Tensor` or `SparseTensor` with shape
   2861       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   2862       num_labels=1. N >= 1 and num_labels is the number of target classes for
   2863       the associated prediction. Commonly, N=1 and `labels` has shape
   2864       [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
   2865       Values should be in range [0, num_classes).
   2866     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
   2867       Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
   2868       dimension must be set and contains the top `k` predicted class indices.
   2869       [D1, ... DN] must match `labels`. Values should be in range
   2870       [0, num_classes).
   2871 
   2872   Returns:
   2873     `float64` `Tensor` of shape [D1, ... DN], where each value is the average
   2874     precision for that row.
   2875 
   2876   Raises:
   2877     ValueError: if the last dimension of predictions_idx is not set.
   2878   """
   2879   with ops.name_scope(None, 'average_precision',
   2880                       (predictions_idx, labels)) as scope:
   2881     predictions_idx = math_ops.to_int64(predictions_idx, name='predictions_idx')
   2882     if predictions_idx.get_shape().ndims == 0:
   2883       raise ValueError('The rank of predictions_idx must be at least 1.')
   2884     k = predictions_idx.get_shape().as_list()[-1]
   2885     if k is None:
   2886       raise ValueError('The last dimension of predictions_idx must be set.')
   2887     labels = _maybe_expand_labels(labels, predictions_idx)
   2888 
   2889     # Expand dims to produce [D1, ... DN, k, 1] tensor. This gives us a separate
   2890     # prediction for each k, so we can calculate separate true positive values
   2891     # for each k.
   2892     predictions_idx_per_k = array_ops.expand_dims(
   2893         predictions_idx, -1, name='predictions_idx_per_k')
   2894 
   2895     # Replicate labels k times to produce [D1, ... DN, k, num_labels] tensor.
   2896     labels_per_k = _expand_and_tile(
   2897         labels, multiple=k, dim=-1, name='labels_per_k')
   2898 
   2899     # The following tensors are all of shape [D1, ... DN, k], containing values
   2900     # per row, per k value.
   2901     # `relevant_per_k` (int32) - Relevance indicator, 1 if the prediction at
   2902     #     that k value is correct, 0 otherwise. This is the "rel_{i}" term from
   2903     #     the formula above.
   2904     # `tp_per_k` (int32) - True positive counts.
   2905     # `retrieved_per_k` (int32) - Number of predicted values at each k. This is
   2906     #     the precision denominator.
   2907     # `precision_per_k` (float64) - Precision at each k. This is the "P_{i}"
   2908     #     term from the formula above.
   2909     # `relevant_precision_per_k` (float64) - Relevant precisions; i.e.,
   2910     #     precisions at all k for which relevance indicator is true.
   2911     relevant_per_k = _sparse_true_positive_at_k(
   2912         labels_per_k, predictions_idx_per_k, name='relevant_per_k')
   2913     tp_per_k = math_ops.cumsum(relevant_per_k, axis=-1, name='tp_per_k')
   2914     retrieved_per_k = math_ops.cumsum(
   2915         array_ops.ones_like(relevant_per_k), axis=-1, name='retrieved_per_k')
   2916     precision_per_k = math_ops.div(
   2917         math_ops.to_double(tp_per_k),
   2918         math_ops.to_double(retrieved_per_k),
   2919         name='precision_per_k')
   2920     relevant_precision_per_k = math_ops.multiply(
   2921         precision_per_k,
   2922         math_ops.to_double(relevant_per_k),
   2923         name='relevant_precision_per_k')
   2924 
   2925     # Reduce along k dimension to get the sum, yielding a [D1, ... DN] tensor.
   2926     precision_sum = math_ops.reduce_sum(
   2927         relevant_precision_per_k, reduction_indices=(-1,), name='precision_sum')
   2928 
   2929     # Divide by number of relevant items to get average precision. These are
   2930     # the "num_relevant_items" and "AveP" terms from the formula above.
   2931     num_relevant_items = math_ops.to_double(_num_relevant(labels, k))
   2932     return math_ops.div(precision_sum, num_relevant_items, name=scope)
   2933 
   2934 
   2935 def _streaming_sparse_average_precision_at_top_k(labels,
   2936                                                  predictions_idx,
   2937                                                  weights=None,
   2938                                                  metrics_collections=None,
   2939                                                  updates_collections=None,
   2940                                                  name=None):
   2941   """Computes average precision@k of predictions with respect to sparse labels.
   2942 
   2943   `_streaming_sparse_average_precision_at_top_k` creates two local variables,
   2944   `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
   2945   are used to compute the frequency. This frequency is ultimately returned as
   2946   `average_precision_at_<k>`: an idempotent operation that simply divides
   2947   `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
   2948 
   2949   For estimation of the metric over a stream of data, the function creates an
   2950   `update_op` operation that updates these variables and returns the
   2951   `average_precision_at_<k>`. Internally, per-example average precision is
   2952   computed from `predictions_idx` and `labels` and weighted by `weights`.
   2953   Then `update_op` increments the `total` and `max` accumulator variables
   2954   using these values.
   2955 
   2956   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   2957 
   2958   Args:
   2959     labels: `int64` `Tensor` or `SparseTensor` with shape
   2960       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   2961       num_labels=1. N >= 1 and num_labels is the number of target classes for
   2962       the associated prediction. Commonly, N=1 and `labels` has shape
   2963       [batch_size, num_labels]. [D1, ... DN] must match `predictions_idx`.
   2964       Values should be in range [0, num_classes).
   2965     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where N >= 1.
   2966       Commonly, N=1 and `predictions_idx` has shape [batch size, k]. The final
   2967       dimension contains the top `k` predicted class indices. [D1, ... DN] must
   2968       match `labels`. Values should be in range [0, num_classes).
   2969     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   2970       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   2971       dimensions must be either `1`, or the same as the corresponding `labels`
   2972       dimension).
   2973     metrics_collections: An optional list of collections that values should
   2974       be added to.
   2975     updates_collections: An optional list of collections that updates should
   2976       be added to.
   2977     name: Name of new update operation, and namespace for other dependent ops.
   2978 
   2979   Returns:
   2980     mean_average_precision: Scalar `float64` `Tensor` with the mean average
   2981       precision values.
   2982     update: `Operation` that increments variables appropriately, and whose
   2983       value matches `metric`.
   2984   """
   2985   with ops.name_scope(name, 'average_precision_at_top_k',
   2986                       (predictions_idx, labels, weights)) as scope:
   2987     # Calculate per-example average precision, and apply weights.
   2988     average_precision = _sparse_average_precision_at_top_k(
   2989         predictions_idx=predictions_idx, labels=labels)
   2990     if weights is not None:
   2991       weights = weights_broadcast_ops.broadcast_weights(
   2992           math_ops.to_double(weights), average_precision)
   2993       average_precision = math_ops.multiply(average_precision, weights)
   2994 
   2995     # Create accumulation variables and update ops for max average precision and
   2996     # total average precision.
   2997     with ops.name_scope(None, 'max', (average_precision,)) as max_scope:
   2998       # `max` is the max possible precision. Since max for any row is 1.0:
   2999       # - For the unweighted case, this is just the number of rows.
   3000       # - For the weighted case, it's the sum of the weights broadcast across
   3001       #   `average_precision` rows.
   3002       max_var = metric_variable([], dtypes.float64, name=max_scope)
   3003       if weights is None:
   3004         batch_max = math_ops.to_double(
   3005             array_ops.size(average_precision, name='batch_max'))
   3006       else:
   3007         batch_max = math_ops.reduce_sum(weights, name='batch_max')
   3008       max_update = state_ops.assign_add(max_var, batch_max, name='update')
   3009     with ops.name_scope(None, 'total', (average_precision,)) as total_scope:
   3010       total_var = metric_variable([], dtypes.float64, name=total_scope)
   3011       batch_total = math_ops.reduce_sum(average_precision, name='batch_total')
   3012       total_update = state_ops.assign_add(total_var, batch_total, name='update')
   3013 
   3014     # Divide total by max to get mean, for both vars and the update ops.
   3015     mean_average_precision = _safe_scalar_div(total_var, max_var, name='mean')
   3016     update = _safe_scalar_div(total_update, max_update, name=scope)
   3017 
   3018     if metrics_collections:
   3019       ops.add_to_collections(metrics_collections, mean_average_precision)
   3020     if updates_collections:
   3021       ops.add_to_collections(updates_collections, update)
   3022 
   3023     return mean_average_precision, update
   3024 
   3025 
   3026 @tf_export('metrics.sparse_average_precision_at_k')
   3027 @deprecated(None, 'Use average_precision_at_k instead')
   3028 def sparse_average_precision_at_k(labels,
   3029                                   predictions,
   3030                                   k,
   3031                                   weights=None,
   3032                                   metrics_collections=None,
   3033                                   updates_collections=None,
   3034                                   name=None):
   3035   """Renamed to `average_precision_at_k`, please use that method instead."""
   3036   return average_precision_at_k(
   3037       labels=labels,
   3038       predictions=predictions,
   3039       k=k,
   3040       weights=weights,
   3041       metrics_collections=metrics_collections,
   3042       updates_collections=updates_collections,
   3043       name=name)
   3044 
   3045 
   3046 @tf_export('metrics.average_precision_at_k')
   3047 def average_precision_at_k(labels,
   3048                            predictions,
   3049                            k,
   3050                            weights=None,
   3051                            metrics_collections=None,
   3052                            updates_collections=None,
   3053                            name=None):
   3054   """Computes average precision@k of predictions with respect to sparse labels.
   3055 
   3056   `average_precision_at_k` creates two local variables,
   3057   `average_precision_at_<k>/total` and `average_precision_at_<k>/max`, that
   3058   are used to compute the frequency. This frequency is ultimately returned as
   3059   `average_precision_at_<k>`: an idempotent operation that simply divides
   3060   `average_precision_at_<k>/total` by `average_precision_at_<k>/max`.
   3061 
   3062   For estimation of the metric over a stream of data, the function creates an
   3063   `update_op` operation that updates these variables and returns the
   3064   `average_precision_at_<k>`. Internally, a `top_k` operation computes a
   3065   `Tensor` indicating the top `k` `predictions`. Per-example average precision
   3066   is computed from these indices and `labels` and weighted by `weights`. Then
   3067   `update_op` increments the `total` and `max` accumulator variables using
   3068   these values.
   3069 
   3070   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   3071 
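          For example, a minimal graph-mode sketch with hypothetical values:

            import tensorflow as tf

            labels = tf.constant([[2, 0]], dtype=tf.int64)
            predictions = tf.constant([[0.4, 0.1, 0.2, 0.3]])
            ap, update_op = tf.metrics.average_precision_at_k(
                labels, predictions, k=3)
            # The top-3 classes are [0, 3, 2]; the two relevant hits have
            # precisions 1/1 and 2/3, so `ap` evaluates to ~0.833 after
            # `update_op` runs.
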
   3072   Args:
   3073     labels: `int64` `Tensor` or `SparseTensor` with shape
   3074       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   3075       num_labels=1. N >= 1 and num_labels is the number of target classes for
   3076       the associated prediction. Commonly, N=1 and `labels` has shape
   3077       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
   3078       should be in range [0, num_classes), where num_classes is the last
   3079       dimension of `predictions`. Values outside this range are ignored.
   3080     predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
   3081       N >= 1. Commonly, N=1 and `predictions` has shape
   3082       [batch size, num_classes]. The final dimension contains the logit values
   3083       for each class. [D1, ... DN] must match `labels`.
   3084     k: Integer, k for @k metric. This will calculate an average precision for
   3085       range `[1,k]`, as documented above.
   3086     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   3087       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   3088       dimensions must be either `1`, or the same as the corresponding `labels`
   3089       dimension).
   3090     metrics_collections: An optional list of collections that values should
   3091       be added to.
   3092     updates_collections: An optional list of collections that updates should
   3093       be added to.
   3094     name: Name of new update operation, and namespace for other dependent ops.
   3095 
   3096   Returns:
   3097     mean_average_precision: Scalar `float64` `Tensor` with the mean average
   3098       precision values.
   3099     update: `Operation` that increments variables appropriately, and whose
   3100       value matches `metric`.
   3101 
   3102   Raises:
   3103     ValueError: if k is invalid.
   3104     RuntimeError: If eager execution is enabled.
   3105   """
   3106   if context.in_eager_mode():
   3107     raise RuntimeError('tf.metrics.average_precision_at_k is not '
   3108                        'supported when eager execution is enabled.')
   3109 
   3110   if k < 1:
   3111     raise ValueError('Invalid k=%s.' % k)
   3112   with ops.name_scope(name, _at_k_name('average_precision', k),
   3113                       (predictions, labels, weights)) as scope:
   3114     # Calculate top k indices to produce [D1, ... DN, k] tensor.
   3115     _, predictions_idx = nn.top_k(predictions, k)
   3116     return _streaming_sparse_average_precision_at_top_k(
   3117         labels=labels,
   3118         predictions_idx=predictions_idx,
   3119         weights=weights,
   3120         metrics_collections=metrics_collections,
   3121         updates_collections=updates_collections,
   3122         name=scope)
   3123 
   3124 
   3125 def _sparse_false_positive_at_k(labels,
   3126                                 predictions_idx,
   3127                                 class_id=None,
   3128                                 weights=None):
   3129   """Calculates false positives for precision@k.
   3130 
   3131   If `class_id` is specified, calculate binary false positives for `class_id`
   3132       only.
   3133   If `class_id` is not specified, calculate metrics for `k` predicted vs
   3134       `n` label classes, where `n` is the 2nd dimension of `labels`.
   3135 
   3136   Args:
   3137     labels: `int64` `Tensor` or `SparseTensor` with shape
   3138       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   3139       target classes for the associated prediction. Commonly, N=1 and `labels`
   3140       has shape [batch_size, num_labels]. [D1, ... DN] must match
   3141       `predictions_idx`.
   3142     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   3143       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   3144       match `labels`.
   3145     class_id: Class for which we want binary metrics.
   3146     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   3147       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   3148       dimensions must be either `1`, or the same as the corresponding `labels`
   3149       dimension).
   3150 
   3151   Returns:
   3152     A [D1, ... DN] `Tensor` of false positive counts.
   3153   """
   3154   with ops.name_scope(None, 'false_positives',
   3155                       (predictions_idx, labels, weights)):
   3156     labels, predictions_idx = _maybe_select_class_id(labels, predictions_idx,
   3157                                                      class_id)
   3158     fp = sets.set_size(
   3159         sets.set_difference(predictions_idx, labels, aminusb=True))
   3160     fp = math_ops.to_double(fp)
   3161     if weights is not None:
   3162       with ops.control_dependencies((weights_broadcast_ops.assert_broadcastable(
   3163           weights, fp),)):
   3164         weights = math_ops.to_double(weights)
   3165         fp = math_ops.multiply(fp, weights)
   3166     return fp
   3167 
   3168 
   3169 def _streaming_sparse_false_positive_at_k(labels,
   3170                                           predictions_idx,
   3171                                           k=None,
   3172                                           class_id=None,
   3173                                           weights=None,
   3174                                           name=None):
   3175   """Calculates weighted per step false positives for precision@k.
   3176 
   3177   If `class_id` is specified, calculate binary false positives for `class_id`
   3178       only.
   3179   If `class_id` is not specified, calculate metrics for `k` predicted vs
   3180       `n` label classes, where `n` is the 2nd dimension of `labels`.
   3181 
   3182   If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
   3183 
   3184   Args:
   3185     labels: `int64` `Tensor` or `SparseTensor` with shape
   3186       [D1, ... DN, num_labels], where N >= 1 and num_labels is the number of
   3187       target classes for the associated prediction. Commonly, N=1 and `labels`
   3188       has shape [batch_size, num_labels]. [D1, ... DN] must match
   3189       `predictions_idx`.
   3190     predictions_idx: 1-D or higher `int64` `Tensor` with last dimension `k`,
   3191       top `k` predicted classes. For rank `n`, the first `n-1` dimensions must
   3192       match `labels`.
   3193     k: Integer, k for @k metric. This is only used for the default op name.
   3194     class_id: Class for which we want binary metrics.
   3195     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   3196       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   3197       dimensions must be either `1`, or the same as the corresponding `labels`
   3198       dimension).
   3199     name: Name of new variable, and namespace for other dependent ops.
   3200 
   3201   Returns:
   3202     A tuple of `Variable` and update `Operation`.
   3203 
   3204   Raises:
   3205     ValueError: If `weights` is not `None` and has an incompatible shape.
   3206   """
   3207   with ops.name_scope(name, _at_k_name('false_positive', k, class_id=class_id),
   3208                       (predictions_idx, labels, weights)) as scope:
   3209     fp = _sparse_false_positive_at_k(
   3210         predictions_idx=predictions_idx,
   3211         labels=labels,
   3212         class_id=class_id,
   3213         weights=weights)
   3214     batch_total_fp = math_ops.to_double(math_ops.reduce_sum(fp))
   3215 
   3216     var = metric_variable([], dtypes.float64, name=scope)
   3217     return var, state_ops.assign_add(var, batch_total_fp, name='update')
   3218 
   3219 
   3220 @tf_export('metrics.precision_at_top_k')
   3221 def precision_at_top_k(labels,
   3222                        predictions_idx,
   3223                        k=None,
   3224                        class_id=None,
   3225                        weights=None,
   3226                        metrics_collections=None,
   3227                        updates_collections=None,
   3228                        name=None):
   3229   """Computes precision@k of the predictions with respect to sparse labels.
   3230 
   3231   Differs from `precision_at_k` in that predictions must be in the form of
   3232   top `k` class indices, whereas `precision_at_k` expects logits. Refer to
   3233   `precision_at_k` for more details.
   3234 
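          For example, a minimal sketch with hypothetical values, assuming the
          top-2 class indices have already been computed:

            import tensorflow as tf

            labels = tf.constant([[1], [2]], dtype=tf.int64)
            predictions_idx = tf.constant([[1, 2], [0, 3]], dtype=tf.int64)
            precision, update_op = tf.metrics.precision_at_top_k(
                labels, predictions_idx, k=2)
            # Of the 4 predicted classes, only class 1 in the first row is a
            # label, so `precision` evaluates to 0.25 after `update_op` runs.
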
   3235   Args:
   3236     labels: `int64` `Tensor` or `SparseTensor` with shape
   3237       [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
   3238       num_labels=1. N >= 1 and num_labels is the number of target classes for
   3239       the associated prediction. Commonly, N=1 and `labels` has shape
   3240       [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
   3241       should be in range [0, num_classes), where num_classes is the last
   3242       dimension of `predictions`. Values outside this range are ignored.
   3243     predictions_idx: Integer `Tensor` with shape [D1, ... DN, k] where
   3244       N >= 1. Commonly, N=1 and predictions has shape [batch size, k].
   3245       The final dimension contains the top `k` predicted class indices.
   3246       [D1, ... DN] must match `labels`.
   3247     k: Integer, k for @k metric. Only used for the default op name.
   3248     class_id: Integer class ID for which we want binary metrics. This should be
     3249       in range [0, num_classes), where num_classes is the last dimension of
   3250       `predictions`. If `class_id` is outside this range, the method returns
   3251       NAN.
   3252     weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
   3253       `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
   3254       dimensions must be either `1`, or the same as the corresponding `labels`
   3255       dimension).
   3256     metrics_collections: An optional list of collections that values should
   3257       be added to.
   3258     updates_collections: An optional list of collections that updates should
   3259       be added to.
   3260     name: Name of new update operation, and namespace for other dependent ops.
   3261 
   3262   Returns:
   3263     precision: Scalar `float64` `Tensor` with the value of `true_positives`
   3264       divided by the sum of `true_positives` and `false_positives`.
   3265     update_op: `Operation` that increments `true_positives` and
   3266       `false_positives` variables appropriately, and whose value matches
   3267       `precision`.
   3268 
   3269   Raises:
   3270     ValueError: If `weights` is not `None` and its shape doesn't match
   3271       `predictions`, or if either `metrics_collections` or `updates_collections`
   3272       are not a list or tuple.
   3273     RuntimeError: If eager execution is enabled.
   3274   """
   3275   if context.in_eager_mode():
   3276     raise RuntimeError('tf.metrics.precision_at_top_k is not '
   3277                        'supported when eager execution is enabled.')
   3278 
   3279   with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
   3280                       (predictions_idx, labels, weights)) as scope:
   3281     labels = _maybe_expand_labels(labels, predictions_idx)
   3282     top_k_idx = math_ops.to_int64(predictions_idx)
   3283     tp, tp_update = _streaming_sparse_true_positive_at_k(
   3284         predictions_idx=top_k_idx,
   3285         labels=labels,
   3286         k=k,
   3287         class_id=class_id,
   3288         weights=weights)
   3289     fp, fp_update = _streaming_sparse_false_positive_at_k(
   3290         predictions_idx=top_k_idx,
   3291         labels=labels,
   3292         k=k,
   3293         class_id=class_id,
   3294         weights=weights)
   3295 
    metric = math_ops.div(tp, math_ops.add(tp, fp), name=scope)
    update = math_ops.div(
        tp_update, math_ops.add(tp_update, fp_update), name='update')
    if metrics_collections:
      ops.add_to_collections(metrics_collections, metric)
    if updates_collections:
      ops.add_to_collections(updates_collections, update)
    return metric, update


@tf_export('metrics.sparse_precision_at_k')
@deprecated(None, 'Use precision_at_k instead')
def sparse_precision_at_k(labels,
                          predictions,
                          k,
                          class_id=None,
                          weights=None,
                          metrics_collections=None,
                          updates_collections=None,
                          name=None):
  """Renamed to `precision_at_k`, please use that method instead."""
  return precision_at_k(
      labels=labels,
      predictions=predictions,
      k=k,
      class_id=class_id,
      weights=weights,
      metrics_collections=metrics_collections,
      updates_collections=updates_collections,
      name=name)


@tf_export('metrics.precision_at_k')
def precision_at_k(labels,
                   predictions,
                   k,
                   class_id=None,
                   weights=None,
                   metrics_collections=None,
                   updates_collections=None,
                   name=None):
  """Computes precision@k of the predictions with respect to sparse labels.

  If `class_id` is specified, we calculate precision by considering only the
      entries in the batch for which `class_id` is in the top-k highest
      `predictions`, and computing the fraction of them for which `class_id` is
      indeed a correct label.
  If `class_id` is not specified, we calculate precision as how often, on
      average, a class among the top-k highest-scoring classes for a batch
      entry is actually one of that entry's labels.

  `precision_at_k` creates two local variables,
  `true_positive_at_<k>` and `false_positive_at_<k>`, that are used to compute
  the precision@k frequency. This frequency is ultimately returned as
  `precision_at_<k>`: an idempotent operation that simply divides
  `true_positive_at_<k>` by the total (`true_positive_at_<k>` +
  `false_positive_at_<k>`).

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `precision_at_<k>`. Internally, a `top_k` operation computes a `Tensor`
  indicating the top `k` `predictions`. Set operations applied to `top_k` and
  `labels` calculate the true positives and false positives weighted by
  `weights`. Then `update_op` increments `true_positive_at_<k>` and
  `false_positive_at_<k>` using these values.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.
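
  For example, a minimal usage sketch (the values here are hypothetical; note
  the metric variables live in the local-variables collection):

  ```python
  labels = tf.constant([[0], [1]], dtype=tf.int64)     # shape [batch, 1]
  predictions = tf.constant([[0.8, 0.2], [0.1, 0.9]])  # shape [batch, classes]
  prec, update_op = tf.metrics.precision_at_k(labels, predictions, k=1)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(prec))  # 1.0: the top-1 class matches the label in each row
  ```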

  Args:
    labels: `int64` `Tensor` or `SparseTensor` with shape
      [D1, ... DN, num_labels] or [D1, ... DN], where the latter implies
      num_labels=1. N >= 1 and num_labels is the number of target classes for
      the associated prediction. Commonly, N=1 and `labels` has shape
      [batch_size, num_labels]. [D1, ... DN] must match `predictions`. Values
      should be in range [0, num_classes), where num_classes is the last
      dimension of `predictions`. Values outside this range are ignored.
    predictions: Float `Tensor` with shape [D1, ... DN, num_classes] where
      N >= 1. Commonly, N=1 and predictions has shape [batch size, num_classes].
      The final dimension contains the logit values for each class. [D1, ... DN]
      must match `labels`.
    k: Integer, k for @k metric.
    class_id: Integer class ID for which we want binary metrics. This should be
      in range [0, num_classes), where num_classes is the last dimension of
      `predictions`. If `class_id` is outside this range, the method returns
      NAN.
    weights: `Tensor` whose rank is either 0, or n-1, where n is the rank of
      `labels`. If the latter, it must be broadcastable to `labels` (i.e., all
      dimensions must be either `1`, or the same as the corresponding `labels`
      dimension).
    metrics_collections: An optional list of collections that values should
      be added to.
    updates_collections: An optional list of collections that updates should
      be added to.
    name: Name of new update operation, and namespace for other dependent ops.

  Returns:
    precision: Scalar `float64` `Tensor` with the value of `true_positives`
      divided by the sum of `true_positives` and `false_positives`.
    update_op: `Operation` that increments `true_positives` and
      `false_positives` variables appropriately, and whose value matches
      `precision`.

  Raises:
    ValueError: If `weights` is not `None` and its shape doesn't match
      `predictions`, or if either `metrics_collections` or `updates_collections`
      are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.precision_at_k is not '
                       'supported when eager execution is enabled.')

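  # Convert the dense class scores to top-k indices once, then delegate the
  # streaming TP/FP bookkeeping to precision_at_top_k.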
  with ops.name_scope(name, _at_k_name('precision', k, class_id=class_id),
                      (predictions, labels, weights)) as scope:
    _, top_k_idx = nn.top_k(predictions, k)
    return precision_at_top_k(
        labels=labels,
        predictions_idx=top_k_idx,
        k=k,
        class_id=class_id,
        weights=weights,
        metrics_collections=metrics_collections,
        updates_collections=updates_collections,
        name=scope)


@tf_export('metrics.specificity_at_sensitivity')
def specificity_at_sensitivity(labels,
                               predictions,
                               sensitivity,
                               weights=None,
                               num_thresholds=200,
                               metrics_collections=None,
                               updates_collections=None,
                               name=None):
  """Computes the specificity at a given sensitivity.

  The `specificity_at_sensitivity` function creates four local
  variables, `true_positives`, `true_negatives`, `false_positives` and
  `false_negatives` that are used to compute the specificity at the given
  sensitivity value. The threshold for the given sensitivity value is computed
  and used to evaluate the corresponding specificity.

  For estimation of the metric over a stream of data, the function creates an
  `update_op` operation that updates these variables and returns the
  `specificity`. `update_op` increments the `true_positives`, `true_negatives`,
  `false_positives` and `false_negatives` counts with the weight of each case
  found in the `predictions` and `labels`.

  If `weights` is `None`, weights default to 1. Use weights of 0 to mask values.

  For additional information about specificity and sensitivity, see the
  following: https://en.wikipedia.org/wiki/Sensitivity_and_specificity
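
  For example, a minimal usage sketch (the values here are hypothetical):

  ```python
  labels = tf.constant([0, 0, 1, 1])
  predictions = tf.constant([0.1, 0.4, 0.35, 0.8])
  spec, update_op = tf.metrics.specificity_at_sensitivity(
      labels, predictions, sensitivity=0.5)
  with tf.Session() as sess:
    sess.run(tf.local_variables_initializer())
    sess.run(update_op)
    print(sess.run(spec))  # specificity at the threshold whose sensitivity is
                           # closest to 0.5
  ```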

  Args:
    labels: The ground truth values, a `Tensor` whose dimensions must match
      `predictions`. Will be cast to `bool`.
    predictions: A floating point `Tensor` of arbitrary shape and whose values
      are in the range `[0, 1]`.
    sensitivity: A scalar value in range `[0, 1]`.
    weights: Optional `Tensor` whose rank is either 0, or the same rank as
      `labels`, and must be broadcastable to `labels` (i.e., all dimensions must
      be either `1`, or the same as the corresponding `labels` dimension).
    num_thresholds: The number of thresholds to use for matching the given
      sensitivity.
    metrics_collections: An optional list of collections that `specificity`
      should be added to.
    updates_collections: An optional list of collections that `update_op` should
      be added to.
    name: An optional variable_scope name.

  Returns:
    specificity: A scalar `Tensor` representing the specificity at the given
      `sensitivity` value.
    update_op: An operation that increments the `true_positives`,
      `true_negatives`, `false_positives` and `false_negatives` variables
      appropriately and whose value matches `specificity`.

  Raises:
    ValueError: If `predictions` and `labels` have mismatched shapes, if
      `weights` is not `None` and its shape doesn't match `predictions`, or if
      `sensitivity` is not between 0 and 1, or if either `metrics_collections`
      or `updates_collections` are not a list or tuple.
    RuntimeError: If eager execution is enabled.
  """
  if context.in_eager_mode():
    raise RuntimeError('tf.metrics.specificity_at_sensitivity is not '
                       'supported when eager execution is enabled.')

  if sensitivity < 0 or sensitivity > 1:
    raise ValueError('`sensitivity` must be in the range [0, 1].')

  with variable_scope.variable_scope(name, 'specificity_at_sensitivity',
                                     (predictions, labels, weights)):
    kepsilon = 1e-7  # to account for floating point imprecisions
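    # Build num_thresholds points in total: num_thresholds - 2 evenly spaced
    # interior thresholds, plus endpoints nudged just below 0.0 and 1.0 so
    # that boundary predictions compare strictly against some threshold.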
    thresholds = [
        (i + 1) * 1.0 / (num_thresholds - 1) for i in range(num_thresholds - 2)
    ]
    thresholds = [0.0 - kepsilon] + thresholds + [1.0 - kepsilon]

    values, update_ops = _confusion_matrix_at_thresholds(
        labels, predictions, thresholds, weights)

    def compute_specificity_at_sensitivity(tp, tn, fp, fn, name):
      """Computes the specificity at the given sensitivity.

      Args:
        tp: True positives.
        tn: True negatives.
        fp: False positives.
        fn: False negatives.
        name: The name of the operation.

      Returns:
        The specificity using the aggregated values.
      """
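      # Sensitivity (recall) at each threshold: TP / (TP + FN). kepsilon keeps
      # the division well-defined when there are no positive labels.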
      sensitivities = math_ops.div(tp, tp + fn + kepsilon)

      # We'll need to use this trick until tf.argmax allows us to specify
      # whether we should use the first or last index in case of ties.
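      # The cumulative sum of the 0/1 indicator reaches its maximum at the
      # last index whose sensitivity is closest to the target, and argmax
      # returns the first index attaining that maximum, i.e. the last (and
      # hence largest) qualifying threshold.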
      min_val = math_ops.reduce_min(math_ops.abs(sensitivities - sensitivity))
      indices_at_minval = math_ops.equal(
          math_ops.abs(sensitivities - sensitivity), min_val)
      indices_at_minval = math_ops.to_int64(indices_at_minval)
      indices_at_minval = math_ops.cumsum(indices_at_minval)
      tf_index = math_ops.argmax(indices_at_minval, 0)
      tf_index = math_ops.cast(tf_index, dtypes.int32)

      # Now, we have the implicit threshold, so compute the specificity:
      return math_ops.div(tn[tf_index], tn[tf_index] + fp[tf_index] + kepsilon,
                          name)

    specificity = compute_specificity_at_sensitivity(
        values['tp'], values['tn'], values['fp'], values['fn'], 'value')
    update_op = compute_specificity_at_sensitivity(
        update_ops['tp'], update_ops['tn'], update_ops['fp'], update_ops['fn'],
        'update_op')

    if metrics_collections:
      ops.add_to_collections(metrics_collections, specificity)

    if updates_collections:
      ops.add_to_collections(updates_collections, update_op)

    return specificity, update_op
   3541