Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for metrics."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import functools
     22 import math
     23 
     24 import numpy as np
     25 from six.moves import xrange  # pylint: disable=redefined-builtin
     26 
     27 from tensorflow.python.framework import constant_op
     28 from tensorflow.python.framework import dtypes as dtypes_lib
     29 from tensorflow.python.framework import errors_impl
     30 from tensorflow.python.framework import ops
     31 from tensorflow.python.framework import sparse_tensor
     32 from tensorflow.python.ops import array_ops
     33 from tensorflow.python.ops import data_flow_ops
     34 from tensorflow.python.ops import math_ops
     35 from tensorflow.python.ops import metrics
     36 from tensorflow.python.ops import random_ops
     37 from tensorflow.python.ops import variables
     38 import tensorflow.python.ops.data_flow_grad  # pylint: disable=unused-import
     39 import tensorflow.python.ops.nn_grad  # pylint: disable=unused-import
     40 from tensorflow.python.platform import test
     41 
     42 NAN = float('nan')
     43 
     44 
     45 def _enqueue_vector(sess, queue, values, shape=None):
     46   if not shape:
     47     shape = (1, len(values))
     48   dtype = queue.dtypes[0]
     49   sess.run(
     50       queue.enqueue(constant_op.constant(
     51           values, dtype=dtype, shape=shape)))
     52 
     53 
     54 def _binary_2d_label_to_2d_sparse_value(labels):
     55   """Convert dense 2D binary indicator to sparse ID.
     56 
     57   Only 1 values in `labels` are included in result.
     58 
     59   Args:
     60     labels: Dense 2D binary indicator, shape [batch_size, num_classes].
     61 
     62   Returns:
     63     `SparseTensorValue` of shape [batch_size, num_classes], where num_classes
     64     is the number of `1` values in each row of `labels`. Values are indices
     65     of `1` values along the last dimension of `labels`.
     66   """
     67   indices = []
     68   values = []
     69   batch = 0
     70   for row in labels:
     71     label = 0
     72     xi = 0
     73     for x in row:
     74       if x == 1:
     75         indices.append([batch, xi])
     76         values.append(label)
     77         xi += 1
     78       else:
     79         assert x == 0
     80       label += 1
     81     batch += 1
     82   shape = [len(labels), len(labels[0])]
     83   return sparse_tensor.SparseTensorValue(
     84       np.array(indices, np.int64),
     85       np.array(values, np.int64), np.array(shape, np.int64))
     86 
     87 
     88 def _binary_2d_label_to_1d_sparse_value(labels):
     89   """Convert dense 2D binary indicator to sparse ID.
     90 
     91   Only 1 values in `labels` are included in result.
     92 
     93   Args:
     94     labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Each
     95     row must contain exactly 1 `1` value.
     96 
     97   Returns:
     98     `SparseTensorValue` of shape [batch_size]. Values are indices of `1` values
     99     along the last dimension of `labels`.
    100 
    101   Raises:
    102     ValueError: if there is not exactly 1 `1` value per row of `labels`.
    103   """
    104   indices = []
    105   values = []
    106   batch = 0
    107   for row in labels:
    108     label = 0
    109     xi = 0
    110     for x in row:
    111       if x == 1:
    112         indices.append([batch])
    113         values.append(label)
    114         xi += 1
    115       else:
    116         assert x == 0
    117       label += 1
    118     batch += 1
    119   if indices != [[i] for i in range(len(labels))]:
    120     raise ValueError('Expected 1 label/example, got %s.' % indices)
    121   shape = [len(labels)]
    122   return sparse_tensor.SparseTensorValue(
    123       np.array(indices, np.int64),
    124       np.array(values, np.int64), np.array(shape, np.int64))
    125 
    126 
    127 def _binary_3d_label_to_sparse_value(labels):
    128   """Convert dense 3D binary indicator tensor to sparse tensor.
    129 
    130   Only 1 values in `labels` are included in result.
    131 
    132   Args:
    133     labels: Dense 2D binary indicator tensor.
    134 
    135   Returns:
    136     `SparseTensorValue` whose values are indices along the last dimension of
    137     `labels`.
    138   """
    139   indices = []
    140   values = []
    141   for d0, labels_d0 in enumerate(labels):
    142     for d1, labels_d1 in enumerate(labels_d0):
    143       d2 = 0
    144       for class_id, label in enumerate(labels_d1):
    145         if label == 1:
    146           values.append(class_id)
    147           indices.append([d0, d1, d2])
    148           d2 += 1
    149         else:
    150           assert label == 0
    151   shape = [len(labels), len(labels[0]), len(labels[0][0])]
    152   return sparse_tensor.SparseTensorValue(
    153       np.array(indices, np.int64),
    154       np.array(values, np.int64), np.array(shape, np.int64))
    155 
    156 
    157 def _assert_nan(test_case, actual):
    158   test_case.assertTrue(math.isnan(actual), 'Expected NAN, got %s.' % actual)
    159 
    160 
    161 def _assert_metric_variables(test_case, expected):
    162   test_case.assertEquals(
    163       set(expected), set(v.name for v in variables.local_variables()))
    164   test_case.assertEquals(
    165       set(expected),
    166       set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES)))
    167 
    168 
    169 def _test_values(shape):
    170   return np.reshape(np.cumsum(np.ones(shape)), newshape=shape)
    171 
    172 
    173 class MeanTest(test.TestCase):
    174 
    175   def setUp(self):
    176     ops.reset_default_graph()
    177 
    178   def testVars(self):
    179     metrics.mean(array_ops.ones([4, 3]))
    180     _assert_metric_variables(self, ('mean/count:0', 'mean/total:0'))
    181 
    182   def testMetricsCollection(self):
    183     my_collection_name = '__metrics__'
    184     mean, _ = metrics.mean(
    185         array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    186     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
    187 
    188   def testUpdatesCollection(self):
    189     my_collection_name = '__updates__'
    190     _, update_op = metrics.mean(
    191         array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    192     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
    193 
    194   def testBasic(self):
    195     with self.test_session() as sess:
    196       values_queue = data_flow_ops.FIFOQueue(
    197           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    198       _enqueue_vector(sess, values_queue, [0, 1])
    199       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    200       _enqueue_vector(sess, values_queue, [6.5, 0])
    201       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    202       values = values_queue.dequeue()
    203 
    204       mean, update_op = metrics.mean(values)
    205 
    206       sess.run(variables.local_variables_initializer())
    207       for _ in range(4):
    208         sess.run(update_op)
    209       self.assertAlmostEqual(1.65, sess.run(mean), 5)
    210 
    211   def testUpdateOpsReturnsCurrentValue(self):
    212     with self.test_session() as sess:
    213       values_queue = data_flow_ops.FIFOQueue(
    214           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    215       _enqueue_vector(sess, values_queue, [0, 1])
    216       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    217       _enqueue_vector(sess, values_queue, [6.5, 0])
    218       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    219       values = values_queue.dequeue()
    220 
    221       mean, update_op = metrics.mean(values)
    222 
    223       sess.run(variables.local_variables_initializer())
    224 
    225       self.assertAlmostEqual(0.5, sess.run(update_op), 5)
    226       self.assertAlmostEqual(1.475, sess.run(update_op), 5)
    227       self.assertAlmostEqual(12.4 / 6.0, sess.run(update_op), 5)
    228       self.assertAlmostEqual(1.65, sess.run(update_op), 5)
    229 
    230       self.assertAlmostEqual(1.65, sess.run(mean), 5)
    231 
    232   def testUnweighted(self):
    233     values = _test_values((3, 2, 4, 1))
    234     mean_results = (
    235         metrics.mean(values),
    236         metrics.mean(values, weights=1.0),
    237         metrics.mean(values, weights=np.ones((1, 1, 1))),
    238         metrics.mean(values, weights=np.ones((1, 1, 1, 1))),
    239         metrics.mean(values, weights=np.ones((1, 1, 1, 1, 1))),
    240         metrics.mean(values, weights=np.ones((1, 1, 4))),
    241         metrics.mean(values, weights=np.ones((1, 1, 4, 1))),
    242         metrics.mean(values, weights=np.ones((1, 2, 1))),
    243         metrics.mean(values, weights=np.ones((1, 2, 1, 1))),
    244         metrics.mean(values, weights=np.ones((1, 2, 4))),
    245         metrics.mean(values, weights=np.ones((1, 2, 4, 1))),
    246         metrics.mean(values, weights=np.ones((3, 1, 1))),
    247         metrics.mean(values, weights=np.ones((3, 1, 1, 1))),
    248         metrics.mean(values, weights=np.ones((3, 1, 4))),
    249         metrics.mean(values, weights=np.ones((3, 1, 4, 1))),
    250         metrics.mean(values, weights=np.ones((3, 2, 1))),
    251         metrics.mean(values, weights=np.ones((3, 2, 1, 1))),
    252         metrics.mean(values, weights=np.ones((3, 2, 4))),
    253         metrics.mean(values, weights=np.ones((3, 2, 4, 1))),
    254         metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),)
    255     expected = np.mean(values)
    256     with self.test_session():
    257       variables.local_variables_initializer().run()
    258       for mean_result in mean_results:
    259         mean, update_op = mean_result
    260         self.assertAlmostEqual(expected, update_op.eval())
    261         self.assertAlmostEqual(expected, mean.eval())
    262 
    263   def _test_3d_weighted(self, values, weights):
    264     expected = (
    265         np.sum(np.multiply(weights, values)) /
    266         np.sum(np.multiply(weights, np.ones_like(values)))
    267     )
    268     mean, update_op = metrics.mean(values, weights=weights)
    269     with self.test_session():
    270       variables.local_variables_initializer().run()
    271       self.assertAlmostEqual(expected, update_op.eval(), places=5)
    272       self.assertAlmostEqual(expected, mean.eval(), places=5)
    273 
    274   def test1x1x1Weighted(self):
    275     self._test_3d_weighted(
    276         _test_values((3, 2, 4)),
    277         weights=np.asarray((5,)).reshape((1, 1, 1)))
    278 
    279   def test1x1xNWeighted(self):
    280     self._test_3d_weighted(
    281         _test_values((3, 2, 4)),
    282         weights=np.asarray((5, 7, 11, 3)).reshape((1, 1, 4)))
    283 
    284   def test1xNx1Weighted(self):
    285     self._test_3d_weighted(
    286         _test_values((3, 2, 4)),
    287         weights=np.asarray((5, 11)).reshape((1, 2, 1)))
    288 
    289   def test1xNxNWeighted(self):
    290     self._test_3d_weighted(
    291         _test_values((3, 2, 4)),
    292         weights=np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4)))
    293 
    294   def testNx1x1Weighted(self):
    295     self._test_3d_weighted(
    296         _test_values((3, 2, 4)),
    297         weights=np.asarray((5, 7, 11)).reshape((3, 1, 1)))
    298 
    299   def testNx1xNWeighted(self):
    300     self._test_3d_weighted(
    301         _test_values((3, 2, 4)),
    302         weights=np.asarray((
    303             5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4)))
    304 
    305   def testNxNxNWeighted(self):
    306     self._test_3d_weighted(
    307         _test_values((3, 2, 4)),
    308         weights=np.asarray((
    309             5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3,
    310             2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4)))
    311 
    312   def testInvalidWeights(self):
    313     values_placeholder = array_ops.placeholder(dtype=dtypes_lib.float32)
    314     values = _test_values((3, 2, 4, 1))
    315     invalid_weights = (
    316         (1,),
    317         (1, 1),
    318         (3, 2),
    319         (2, 4, 1),
    320         (4, 2, 4, 1),
    321         (3, 3, 4, 1),
    322         (3, 2, 5, 1),
    323         (3, 2, 4, 2),
    324         (1, 1, 1, 1, 1))
    325     expected_error_msg = 'weights can not be broadcast to values'
    326     for invalid_weight in invalid_weights:
    327       # Static shapes.
    328       with self.assertRaisesRegexp(ValueError, expected_error_msg):
    329         metrics.mean(values, invalid_weight)
    330 
    331       # Dynamic shapes.
    332       with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg):
    333         with self.test_session():
    334           _, update_op = metrics.mean(values_placeholder, invalid_weight)
    335           variables.local_variables_initializer().run()
    336           update_op.eval(feed_dict={values_placeholder: values})
    337 
    338 
    339 class MeanTensorTest(test.TestCase):
    340 
    341   def setUp(self):
    342     ops.reset_default_graph()
    343 
    344   def testVars(self):
    345     metrics.mean_tensor(array_ops.ones([4, 3]))
    346     _assert_metric_variables(self,
    347                              ('mean/total_tensor:0', 'mean/count_tensor:0'))
    348 
    349   def testMetricsCollection(self):
    350     my_collection_name = '__metrics__'
    351     mean, _ = metrics.mean_tensor(
    352         array_ops.ones([4, 3]), metrics_collections=[my_collection_name])
    353     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
    354 
    355   def testUpdatesCollection(self):
    356     my_collection_name = '__updates__'
    357     _, update_op = metrics.mean_tensor(
    358         array_ops.ones([4, 3]), updates_collections=[my_collection_name])
    359     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
    360 
    361   def testBasic(self):
    362     with self.test_session() as sess:
    363       values_queue = data_flow_ops.FIFOQueue(
    364           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    365       _enqueue_vector(sess, values_queue, [0, 1])
    366       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    367       _enqueue_vector(sess, values_queue, [6.5, 0])
    368       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    369       values = values_queue.dequeue()
    370 
    371       mean, update_op = metrics.mean_tensor(values)
    372 
    373       sess.run(variables.local_variables_initializer())
    374       for _ in range(4):
    375         sess.run(update_op)
    376       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean))
    377 
    378   def testMultiDimensional(self):
    379     with self.test_session() as sess:
    380       values_queue = data_flow_ops.FIFOQueue(
    381           2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2))
    382       _enqueue_vector(
    383           sess,
    384           values_queue, [[[1, 2], [1, 2]], [[1, 2], [1, 2]]],
    385           shape=(2, 2, 2))
    386       _enqueue_vector(
    387           sess,
    388           values_queue, [[[1, 2], [1, 2]], [[3, 4], [9, 10]]],
    389           shape=(2, 2, 2))
    390       values = values_queue.dequeue()
    391 
    392       mean, update_op = metrics.mean_tensor(values)
    393 
    394       sess.run(variables.local_variables_initializer())
    395       for _ in range(2):
    396         sess.run(update_op)
    397       self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean))
    398 
    399   def testUpdateOpsReturnsCurrentValue(self):
    400     with self.test_session() as sess:
    401       values_queue = data_flow_ops.FIFOQueue(
    402           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    403       _enqueue_vector(sess, values_queue, [0, 1])
    404       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    405       _enqueue_vector(sess, values_queue, [6.5, 0])
    406       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    407       values = values_queue.dequeue()
    408 
    409       mean, update_op = metrics.mean_tensor(values)
    410 
    411       sess.run(variables.local_variables_initializer())
    412 
    413       self.assertAllClose([[0, 1]], sess.run(update_op), 5)
    414       self.assertAllClose([[-2.1, 5.05]], sess.run(update_op), 5)
    415       self.assertAllClose([[2.3 / 3., 10.1 / 3.]], sess.run(update_op), 5)
    416       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(update_op), 5)
    417 
    418       self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5)
    419 
    420   def testWeighted1d(self):
    421     with self.test_session() as sess:
    422       # Create the queue that populates the values.
    423       values_queue = data_flow_ops.FIFOQueue(
    424           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    425       _enqueue_vector(sess, values_queue, [0, 1])
    426       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    427       _enqueue_vector(sess, values_queue, [6.5, 0])
    428       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    429       values = values_queue.dequeue()
    430 
    431       # Create the queue that populates the weights.
    432       weights_queue = data_flow_ops.FIFOQueue(
    433           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
    434       _enqueue_vector(sess, weights_queue, [[1]])
    435       _enqueue_vector(sess, weights_queue, [[0]])
    436       _enqueue_vector(sess, weights_queue, [[1]])
    437       _enqueue_vector(sess, weights_queue, [[0]])
    438       weights = weights_queue.dequeue()
    439 
    440       mean, update_op = metrics.mean_tensor(values, weights)
    441 
    442       sess.run(variables.local_variables_initializer())
    443       for _ in range(4):
    444         sess.run(update_op)
    445       self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5)
    446 
    447   def testWeighted2d_1(self):
    448     with self.test_session() as sess:
    449       # Create the queue that populates the values.
    450       values_queue = data_flow_ops.FIFOQueue(
    451           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    452       _enqueue_vector(sess, values_queue, [0, 1])
    453       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    454       _enqueue_vector(sess, values_queue, [6.5, 0])
    455       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    456       values = values_queue.dequeue()
    457 
    458       # Create the queue that populates the weights.
    459       weights_queue = data_flow_ops.FIFOQueue(
    460           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    461       _enqueue_vector(sess, weights_queue, [1, 1])
    462       _enqueue_vector(sess, weights_queue, [1, 0])
    463       _enqueue_vector(sess, weights_queue, [0, 1])
    464       _enqueue_vector(sess, weights_queue, [0, 0])
    465       weights = weights_queue.dequeue()
    466 
    467       mean, update_op = metrics.mean_tensor(values, weights)
    468 
    469       sess.run(variables.local_variables_initializer())
    470       for _ in range(4):
    471         sess.run(update_op)
    472       self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5)
    473 
    474   def testWeighted2d_2(self):
    475     with self.test_session() as sess:
    476       # Create the queue that populates the values.
    477       values_queue = data_flow_ops.FIFOQueue(
    478           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    479       _enqueue_vector(sess, values_queue, [0, 1])
    480       _enqueue_vector(sess, values_queue, [-4.2, 9.1])
    481       _enqueue_vector(sess, values_queue, [6.5, 0])
    482       _enqueue_vector(sess, values_queue, [-3.2, 4.0])
    483       values = values_queue.dequeue()
    484 
    485       # Create the queue that populates the weights.
    486       weights_queue = data_flow_ops.FIFOQueue(
    487           4, dtypes=dtypes_lib.float32, shapes=(1, 2))
    488       _enqueue_vector(sess, weights_queue, [0, 1])
    489       _enqueue_vector(sess, weights_queue, [0, 0])
    490       _enqueue_vector(sess, weights_queue, [0, 1])
    491       _enqueue_vector(sess, weights_queue, [0, 0])
    492       weights = weights_queue.dequeue()
    493 
    494       mean, update_op = metrics.mean_tensor(values, weights)
    495 
    496       sess.run(variables.local_variables_initializer())
    497       for _ in range(4):
    498         sess.run(update_op)
    499       self.assertAllClose([[0, 0.5]], sess.run(mean), 5)
    500 
    501 
    502 class AccuracyTest(test.TestCase):
    503 
    504   def setUp(self):
    505     ops.reset_default_graph()
    506 
    507   def testVars(self):
    508     metrics.accuracy(
    509         predictions=array_ops.ones((10, 1)),
    510         labels=array_ops.ones((10, 1)),
    511         name='my_accuracy')
    512     _assert_metric_variables(self,
    513                              ('my_accuracy/count:0', 'my_accuracy/total:0'))
    514 
    515   def testMetricsCollection(self):
    516     my_collection_name = '__metrics__'
    517     mean, _ = metrics.accuracy(
    518         predictions=array_ops.ones((10, 1)),
    519         labels=array_ops.ones((10, 1)),
    520         metrics_collections=[my_collection_name])
    521     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
    522 
    523   def testUpdatesCollection(self):
    524     my_collection_name = '__updates__'
    525     _, update_op = metrics.accuracy(
    526         predictions=array_ops.ones((10, 1)),
    527         labels=array_ops.ones((10, 1)),
    528         updates_collections=[my_collection_name])
    529     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
    530 
    531   def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
    532     predictions = array_ops.ones((10, 3))
    533     labels = array_ops.ones((10, 4))
    534     with self.assertRaises(ValueError):
    535       metrics.accuracy(labels, predictions)
    536 
    537   def testPredictionsAndWeightsOfDifferentSizeRaisesValueError(self):
    538     predictions = array_ops.ones((10, 3))
    539     labels = array_ops.ones((10, 3))
    540     weights = array_ops.ones((9, 3))
    541     with self.assertRaises(ValueError):
    542       metrics.accuracy(labels, predictions, weights)
    543 
    544   def testValueTensorIsIdempotent(self):
    545     predictions = random_ops.random_uniform(
    546         (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    547     labels = random_ops.random_uniform(
    548         (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1)
    549     accuracy, update_op = metrics.accuracy(labels, predictions)
    550 
    551     with self.test_session() as sess:
    552       sess.run(variables.local_variables_initializer())
    553 
    554       # Run several updates.
    555       for _ in range(10):
    556         sess.run(update_op)
    557 
    558       # Then verify idempotency.
    559       initial_accuracy = accuracy.eval()
    560       for _ in range(10):
    561         self.assertEqual(initial_accuracy, accuracy.eval())
    562 
    563   def testMultipleUpdates(self):
    564     with self.test_session() as sess:
    565       # Create the queue that populates the predictions.
    566       preds_queue = data_flow_ops.FIFOQueue(
    567           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
    568       _enqueue_vector(sess, preds_queue, [0])
    569       _enqueue_vector(sess, preds_queue, [1])
    570       _enqueue_vector(sess, preds_queue, [2])
    571       _enqueue_vector(sess, preds_queue, [1])
    572       predictions = preds_queue.dequeue()
    573 
    574       # Create the queue that populates the labels.
    575       labels_queue = data_flow_ops.FIFOQueue(
    576           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
    577       _enqueue_vector(sess, labels_queue, [0])
    578       _enqueue_vector(sess, labels_queue, [1])
    579       _enqueue_vector(sess, labels_queue, [1])
    580       _enqueue_vector(sess, labels_queue, [2])
    581       labels = labels_queue.dequeue()
    582 
    583       accuracy, update_op = metrics.accuracy(labels, predictions)
    584 
    585       sess.run(variables.local_variables_initializer())
    586       for _ in xrange(3):
    587         sess.run(update_op)
    588       self.assertEqual(0.5, sess.run(update_op))
    589       self.assertEqual(0.5, accuracy.eval())
    590 
    591   def testEffectivelyEquivalentSizes(self):
    592     predictions = array_ops.ones((40, 1))
    593     labels = array_ops.ones((40,))
    594     with self.test_session() as sess:
    595       accuracy, update_op = metrics.accuracy(labels, predictions)
    596 
    597       sess.run(variables.local_variables_initializer())
    598       self.assertEqual(1.0, update_op.eval())
    599       self.assertEqual(1.0, accuracy.eval())
    600 
    601   def testEffectivelyEquivalentSizesWithScalarWeight(self):
    602     predictions = array_ops.ones((40, 1))
    603     labels = array_ops.ones((40,))
    604     with self.test_session() as sess:
    605       accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0)
    606 
    607       sess.run(variables.local_variables_initializer())
    608       self.assertEqual(1.0, update_op.eval())
    609       self.assertEqual(1.0, accuracy.eval())
    610 
    611   def testEffectivelyEquivalentSizesWithStaticShapedWeight(self):
    612     predictions = ops.convert_to_tensor([1, 1, 1])  # shape 3,
    613     labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]),
    614                                    1)  # shape 3, 1
    615     weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]),
    616                                     1)  # shape 3, 1
    617 
    618     with self.test_session() as sess:
    619       accuracy, update_op = metrics.accuracy(labels, predictions, weights)
    620 
    621       sess.run(variables.local_variables_initializer())
    622       # if streaming_accuracy does not flatten the weight, accuracy would be
    623       # 0.33333334 due to an intended broadcast of weight. Due to flattening,
    624       # it will be higher than .95
    625       self.assertGreater(update_op.eval(), .95)
    626       self.assertGreater(accuracy.eval(), .95)
    627 
    628   def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self):
    629     predictions = ops.convert_to_tensor([1, 1, 1])  # shape 3,
    630     labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]),
    631                                    1)  # shape 3, 1
    632 
    633     weights = [[100], [1], [1]]  # shape 3, 1
    634     weights_placeholder = array_ops.placeholder(
    635         dtype=dtypes_lib.int32, name='weights')
    636     feed_dict = {weights_placeholder: weights}
    637 
    638     with self.test_session() as sess:
    639       accuracy, update_op = metrics.accuracy(labels, predictions,
    640                                              weights_placeholder)
    641 
    642       sess.run(variables.local_variables_initializer())
    643       # if streaming_accuracy does not flatten the weight, accuracy would be
    644       # 0.33333334 due to an intended broadcast of weight. Due to flattening,
    645       # it will be higher than .95
    646       self.assertGreater(update_op.eval(feed_dict=feed_dict), .95)
    647       self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95)
    648 
    649   def testMultipleUpdatesWithWeightedValues(self):
    650     with self.test_session() as sess:
    651       # Create the queue that populates the predictions.
    652       preds_queue = data_flow_ops.FIFOQueue(
    653           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
    654       _enqueue_vector(sess, preds_queue, [0])
    655       _enqueue_vector(sess, preds_queue, [1])
    656       _enqueue_vector(sess, preds_queue, [2])
    657       _enqueue_vector(sess, preds_queue, [1])
    658       predictions = preds_queue.dequeue()
    659 
    660       # Create the queue that populates the labels.
    661       labels_queue = data_flow_ops.FIFOQueue(
    662           4, dtypes=dtypes_lib.float32, shapes=(1, 1))
    663       _enqueue_vector(sess, labels_queue, [0])
    664       _enqueue_vector(sess, labels_queue, [1])
    665       _enqueue_vector(sess, labels_queue, [1])
    666       _enqueue_vector(sess, labels_queue, [2])
    667       labels = labels_queue.dequeue()
    668 
    669       # Create the queue that populates the weights.
    670       weights_queue = data_flow_ops.FIFOQueue(
    671           4, dtypes=dtypes_lib.int64, shapes=(1, 1))
    672       _enqueue_vector(sess, weights_queue, [1])
    673       _enqueue_vector(sess, weights_queue, [1])
    674       _enqueue_vector(sess, weights_queue, [0])
    675       _enqueue_vector(sess, weights_queue, [0])
    676       weights = weights_queue.dequeue()
    677 
    678       accuracy, update_op = metrics.accuracy(labels, predictions, weights)
    679 
    680       sess.run(variables.local_variables_initializer())
    681       for _ in xrange(3):
    682         sess.run(update_op)
    683       self.assertEqual(1.0, sess.run(update_op))
    684       self.assertEqual(1.0, accuracy.eval())
    685 
    686 
    687 class PrecisionTest(test.TestCase):
    688 
    689   def setUp(self):
    690     np.random.seed(1)
    691     ops.reset_default_graph()
    692 
    693   def testVars(self):
    694     metrics.precision(
    695         predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    696     _assert_metric_variables(self, ('precision/false_positives/count:0',
    697                                     'precision/true_positives/count:0'))
    698 
    699   def testMetricsCollection(self):
    700     my_collection_name = '__metrics__'
    701     mean, _ = metrics.precision(
    702         predictions=array_ops.ones((10, 1)),
    703         labels=array_ops.ones((10, 1)),
    704         metrics_collections=[my_collection_name])
    705     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
    706 
    707   def testUpdatesCollection(self):
    708     my_collection_name = '__updates__'
    709     _, update_op = metrics.precision(
    710         predictions=array_ops.ones((10, 1)),
    711         labels=array_ops.ones((10, 1)),
    712         updates_collections=[my_collection_name])
    713     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
    714 
    715   def testValueTensorIsIdempotent(self):
    716     predictions = random_ops.random_uniform(
    717         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    718     labels = random_ops.random_uniform(
    719         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    720     precision, update_op = metrics.precision(labels, predictions)
    721 
    722     with self.test_session() as sess:
    723       sess.run(variables.local_variables_initializer())
    724 
    725       # Run several updates.
    726       for _ in range(10):
    727         sess.run(update_op)
    728 
    729       # Then verify idempotency.
    730       initial_precision = precision.eval()
    731       for _ in range(10):
    732         self.assertEqual(initial_precision, precision.eval())
    733 
    734   def testAllCorrect(self):
    735     inputs = np.random.randint(0, 2, size=(100, 1))
    736 
    737     predictions = constant_op.constant(inputs)
    738     labels = constant_op.constant(inputs)
    739     precision, update_op = metrics.precision(labels, predictions)
    740 
    741     with self.test_session() as sess:
    742       sess.run(variables.local_variables_initializer())
    743       self.assertAlmostEqual(1, sess.run(update_op))
    744       self.assertAlmostEqual(1, precision.eval())
    745 
    746   def testSomeCorrect_multipleInputDtypes(self):
    747     for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
    748       predictions = math_ops.cast(
    749           constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
    750       labels = math_ops.cast(
    751           constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
    752       precision, update_op = metrics.precision(labels, predictions)
    753 
    754       with self.test_session() as sess:
    755         sess.run(variables.local_variables_initializer())
    756         self.assertAlmostEqual(0.5, update_op.eval())
    757         self.assertAlmostEqual(0.5, precision.eval())
    758 
    759   def testWeighted1d(self):
    760     predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    761     labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    762     precision, update_op = metrics.precision(
    763         labels, predictions, weights=constant_op.constant([[2], [5]]))
    764 
    765     with self.test_session():
    766       variables.local_variables_initializer().run()
    767       weighted_tp = 2.0 + 5.0
    768       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
    769       expected_precision = weighted_tp / weighted_positives
    770       self.assertAlmostEqual(expected_precision, update_op.eval())
    771       self.assertAlmostEqual(expected_precision, precision.eval())
    772 
    773   def testWeightedScalar_placeholders(self):
    774     predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    775     labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    776     feed_dict = {
    777         predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
    778         labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    779     }
    780     precision, update_op = metrics.precision(labels, predictions, weights=2)
    781 
    782     with self.test_session():
    783       variables.local_variables_initializer().run()
    784       weighted_tp = 2.0 + 2.0
    785       weighted_positives = (2.0 + 2.0) + (2.0 + 2.0)
    786       expected_precision = weighted_tp / weighted_positives
    787       self.assertAlmostEqual(
    788           expected_precision, update_op.eval(feed_dict=feed_dict))
    789       self.assertAlmostEqual(
    790           expected_precision, precision.eval(feed_dict=feed_dict))
    791 
    792   def testWeighted1d_placeholders(self):
    793     predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    794     labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    795     feed_dict = {
    796         predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
    797         labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    798     }
    799     precision, update_op = metrics.precision(
    800         labels, predictions, weights=constant_op.constant([[2], [5]]))
    801 
    802     with self.test_session():
    803       variables.local_variables_initializer().run()
    804       weighted_tp = 2.0 + 5.0
    805       weighted_positives = (2.0 + 2.0) + (5.0 + 5.0)
    806       expected_precision = weighted_tp / weighted_positives
    807       self.assertAlmostEqual(
    808           expected_precision, update_op.eval(feed_dict=feed_dict))
    809       self.assertAlmostEqual(
    810           expected_precision, precision.eval(feed_dict=feed_dict))
    811 
    812   def testWeighted2d(self):
    813     predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]])
    814     labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    815     precision, update_op = metrics.precision(
    816         labels,
    817         predictions,
    818         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
    819 
    820     with self.test_session():
    821       variables.local_variables_initializer().run()
    822       weighted_tp = 3.0 + 4.0
    823       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
    824       expected_precision = weighted_tp / weighted_positives
    825       self.assertAlmostEqual(expected_precision, update_op.eval())
    826       self.assertAlmostEqual(expected_precision, precision.eval())
    827 
    828   def testWeighted2d_placeholders(self):
    829     predictions = array_ops.placeholder(dtype=dtypes_lib.float32)
    830     labels = array_ops.placeholder(dtype=dtypes_lib.float32)
    831     feed_dict = {
    832         predictions: ((1, 0, 1, 0), (1, 0, 1, 0)),
    833         labels: ((0, 1, 1, 0), (1, 0, 0, 1))
    834     }
    835     precision, update_op = metrics.precision(
    836         labels,
    837         predictions,
    838         weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]]))
    839 
    840     with self.test_session():
    841       variables.local_variables_initializer().run()
    842       weighted_tp = 3.0 + 4.0
    843       weighted_positives = (1.0 + 3.0) + (4.0 + 2.0)
    844       expected_precision = weighted_tp / weighted_positives
    845       self.assertAlmostEqual(
    846           expected_precision, update_op.eval(feed_dict=feed_dict))
    847       self.assertAlmostEqual(
    848           expected_precision, precision.eval(feed_dict=feed_dict))
    849 
    850   def testAllIncorrect(self):
    851     inputs = np.random.randint(0, 2, size=(100, 1))
    852 
    853     predictions = constant_op.constant(inputs)
    854     labels = constant_op.constant(1 - inputs)
    855     precision, update_op = metrics.precision(labels, predictions)
    856 
    857     with self.test_session() as sess:
    858       sess.run(variables.local_variables_initializer())
    859       sess.run(update_op)
    860       self.assertAlmostEqual(0, precision.eval())
    861 
    862   def testZeroTrueAndFalsePositivesGivesZeroPrecision(self):
    863     predictions = constant_op.constant([0, 0, 0, 0])
    864     labels = constant_op.constant([0, 0, 0, 0])
    865     precision, update_op = metrics.precision(labels, predictions)
    866 
    867     with self.test_session() as sess:
    868       sess.run(variables.local_variables_initializer())
    869       sess.run(update_op)
    870       self.assertEqual(0.0, precision.eval())
    871 
    872 
    873 class RecallTest(test.TestCase):
    874 
    875   def setUp(self):
    876     np.random.seed(1)
    877     ops.reset_default_graph()
    878 
    879   def testVars(self):
    880     metrics.recall(
    881         predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    882     _assert_metric_variables(
    883         self,
    884         ('recall/false_negatives/count:0', 'recall/true_positives/count:0'))
    885 
    886   def testMetricsCollection(self):
    887     my_collection_name = '__metrics__'
    888     mean, _ = metrics.recall(
    889         predictions=array_ops.ones((10, 1)),
    890         labels=array_ops.ones((10, 1)),
    891         metrics_collections=[my_collection_name])
    892     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
    893 
    894   def testUpdatesCollection(self):
    895     my_collection_name = '__updates__'
    896     _, update_op = metrics.recall(
    897         predictions=array_ops.ones((10, 1)),
    898         labels=array_ops.ones((10, 1)),
    899         updates_collections=[my_collection_name])
    900     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
    901 
    902   def testValueTensorIsIdempotent(self):
    903     predictions = random_ops.random_uniform(
    904         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    905     labels = random_ops.random_uniform(
    906         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
    907     recall, update_op = metrics.recall(labels, predictions)
    908 
    909     with self.test_session() as sess:
    910       sess.run(variables.local_variables_initializer())
    911 
    912       # Run several updates.
    913       for _ in range(10):
    914         sess.run(update_op)
    915 
    916       # Then verify idempotency.
    917       initial_recall = recall.eval()
    918       for _ in range(10):
    919         self.assertEqual(initial_recall, recall.eval())
    920 
    921   def testAllCorrect(self):
    922     np_inputs = np.random.randint(0, 2, size=(100, 1))
    923 
    924     predictions = constant_op.constant(np_inputs)
    925     labels = constant_op.constant(np_inputs)
    926     recall, update_op = metrics.recall(labels, predictions)
    927 
    928     with self.test_session() as sess:
    929       sess.run(variables.local_variables_initializer())
    930       sess.run(update_op)
    931       self.assertEqual(1, recall.eval())
    932 
    933   def testSomeCorrect_multipleInputDtypes(self):
    934     for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
    935       predictions = math_ops.cast(
    936           constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
    937       labels = math_ops.cast(
    938           constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
    939       recall, update_op = metrics.recall(labels, predictions)
    940 
    941       with self.test_session() as sess:
    942         sess.run(variables.local_variables_initializer())
    943         self.assertAlmostEqual(0.5, update_op.eval())
    944         self.assertAlmostEqual(0.5, recall.eval())
    945 
    946   def testWeighted1d(self):
    947     predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    948     labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    949     weights = constant_op.constant([[2], [5]])
    950     recall, update_op = metrics.recall(labels, predictions, weights=weights)
    951 
    952     with self.test_session() as sess:
    953       sess.run(variables.local_variables_initializer())
    954       weighted_tp = 2.0 + 5.0
    955       weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
    956       expected_precision = weighted_tp / weighted_t
    957       self.assertAlmostEqual(expected_precision, update_op.eval())
    958       self.assertAlmostEqual(expected_precision, recall.eval())
    959 
    960   def testWeighted2d(self):
    961     predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    962     labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    963     weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
    964     recall, update_op = metrics.recall(labels, predictions, weights=weights)
    965 
    966     with self.test_session() as sess:
    967       sess.run(variables.local_variables_initializer())
    968       weighted_tp = 3.0 + 1.0
    969       weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
    970       expected_precision = weighted_tp / weighted_t
    971       self.assertAlmostEqual(expected_precision, update_op.eval())
    972       self.assertAlmostEqual(expected_precision, recall.eval())
    973 
    974   def testAllIncorrect(self):
    975     np_inputs = np.random.randint(0, 2, size=(100, 1))
    976 
    977     predictions = constant_op.constant(np_inputs)
    978     labels = constant_op.constant(1 - np_inputs)
    979     recall, update_op = metrics.recall(labels, predictions)
    980 
    981     with self.test_session() as sess:
    982       sess.run(variables.local_variables_initializer())
    983       sess.run(update_op)
    984       self.assertEqual(0, recall.eval())
    985 
    986   def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self):
    987     predictions = array_ops.zeros((1, 4))
    988     labels = array_ops.zeros((1, 4))
    989     recall, update_op = metrics.recall(labels, predictions)
    990 
    991     with self.test_session() as sess:
    992       sess.run(variables.local_variables_initializer())
    993       sess.run(update_op)
    994       self.assertEqual(0, recall.eval())
    995 
    996 
    997 class AUCTest(test.TestCase):
    998 
    999   def setUp(self):
   1000     np.random.seed(1)
   1001     ops.reset_default_graph()
   1002 
   1003   def testVars(self):
   1004     metrics.auc(predictions=array_ops.ones((10, 1)),
   1005                 labels=array_ops.ones((10, 1)))
   1006     _assert_metric_variables(self,
   1007                              ('auc/true_positives:0', 'auc/false_negatives:0',
   1008                               'auc/false_positives:0', 'auc/true_negatives:0'))
   1009 
   1010   def testMetricsCollection(self):
   1011     my_collection_name = '__metrics__'
   1012     mean, _ = metrics.auc(predictions=array_ops.ones((10, 1)),
   1013                           labels=array_ops.ones((10, 1)),
   1014                           metrics_collections=[my_collection_name])
   1015     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   1016 
   1017   def testUpdatesCollection(self):
   1018     my_collection_name = '__updates__'
   1019     _, update_op = metrics.auc(predictions=array_ops.ones((10, 1)),
   1020                                labels=array_ops.ones((10, 1)),
   1021                                updates_collections=[my_collection_name])
   1022     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   1023 
   1024   def testValueTensorIsIdempotent(self):
   1025     predictions = random_ops.random_uniform(
   1026         (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
   1027     labels = random_ops.random_uniform(
   1028         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
   1029     auc, update_op = metrics.auc(labels, predictions)
   1030 
   1031     with self.test_session() as sess:
   1032       sess.run(variables.local_variables_initializer())
   1033 
   1034       # Run several updates.
   1035       for _ in range(10):
   1036         sess.run(update_op)
   1037 
   1038       # Then verify idempotency.
   1039       initial_auc = auc.eval()
   1040       for _ in range(10):
   1041         self.assertAlmostEqual(initial_auc, auc.eval(), 5)
   1042 
   1043   def testAllCorrect(self):
   1044     self.allCorrectAsExpected('ROC')
   1045 
   1046   def allCorrectAsExpected(self, curve):
   1047     inputs = np.random.randint(0, 2, size=(100, 1))
   1048 
   1049     with self.test_session() as sess:
   1050       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1051       labels = constant_op.constant(inputs)
   1052       auc, update_op = metrics.auc(labels, predictions, curve=curve)
   1053 
   1054       sess.run(variables.local_variables_initializer())
   1055       self.assertEqual(1, sess.run(update_op))
   1056 
   1057       self.assertEqual(1, auc.eval())
   1058 
   1059   def testSomeCorrect_multipleLabelDtypes(self):
   1060     with self.test_session() as sess:
   1061       for label_dtype in (
   1062           dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
   1063         predictions = constant_op.constant(
   1064             [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
   1065         labels = math_ops.cast(
   1066             constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
   1067         auc, update_op = metrics.auc(labels, predictions)
   1068 
   1069         sess.run(variables.local_variables_initializer())
   1070         self.assertAlmostEqual(0.5, sess.run(update_op))
   1071 
   1072         self.assertAlmostEqual(0.5, auc.eval())
   1073 
   1074   def testWeighted1d(self):
   1075     with self.test_session() as sess:
   1076       predictions = constant_op.constant(
   1077           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
   1078       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
   1079       weights = constant_op.constant([2], shape=(1, 1))
   1080       auc, update_op = metrics.auc(labels, predictions, weights=weights)
   1081 
   1082       sess.run(variables.local_variables_initializer())
   1083       self.assertAlmostEqual(0.5, sess.run(update_op), 5)
   1084 
   1085       self.assertAlmostEqual(0.5, auc.eval(), 5)
   1086 
   1087   def testWeighted2d(self):
   1088     with self.test_session() as sess:
   1089       predictions = constant_op.constant(
   1090           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
   1091       labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
   1092       weights = constant_op.constant([1, 2, 3, 4], shape=(1, 4))
   1093       auc, update_op = metrics.auc(labels, predictions, weights=weights)
   1094 
   1095       sess.run(variables.local_variables_initializer())
   1096       self.assertAlmostEqual(0.7, sess.run(update_op), 5)
   1097 
   1098       self.assertAlmostEqual(0.7, auc.eval(), 5)
   1099 
   1100   def testAUCPRSpecialCase(self):
   1101     with self.test_session() as sess:
   1102       predictions = constant_op.constant(
   1103           [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
   1104       labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
   1105       auc, update_op = metrics.auc(labels, predictions, curve='PR')
   1106 
   1107       sess.run(variables.local_variables_initializer())
   1108       self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3)
   1109 
   1110       self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3)
   1111 
   1112   def testAnotherAUCPRSpecialCase(self):
   1113     with self.test_session() as sess:
   1114       predictions = constant_op.constant(
   1115           [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
   1116           shape=(1, 7),
   1117           dtype=dtypes_lib.float32)
   1118       labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
   1119       auc, update_op = metrics.auc(labels, predictions, curve='PR')
   1120 
   1121       sess.run(variables.local_variables_initializer())
   1122       self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)
   1123 
   1124       self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)
   1125 
   1126   def testThirdAUCPRSpecialCase(self):
   1127     with self.test_session() as sess:
   1128       predictions = constant_op.constant(
   1129           [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
   1130           shape=(1, 7),
   1131           dtype=dtypes_lib.float32)
   1132       labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
   1133       auc, update_op = metrics.auc(labels, predictions, curve='PR')
   1134 
   1135       sess.run(variables.local_variables_initializer())
   1136       self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)
   1137 
   1138       self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)
   1139 
   1140   def testFourthAUCPRSpecialCase(self):
   1141     # Create the labels and data.
   1142     labels = np.array([
   1143         0, 0, 0, 0, 0, 0, 0, 1, 0, 1])
   1144     predictions = np.array([
   1145         0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35])
   1146 
   1147     with self.test_session() as sess:
   1148       auc, _ = metrics.auc(
   1149           labels, predictions, curve='PR', num_thresholds=11)
   1150 
   1151       sess.run(variables.local_variables_initializer())
   1152       # Since this is only approximate, we can't expect a 6 digits match.
   1153       # Although with higher number of samples/thresholds we should see the
   1154       # accuracy improving
   1155       self.assertAlmostEqual(0.0, auc.eval(), delta=0.001)
   1156 
   1157   def testAllIncorrect(self):
   1158     inputs = np.random.randint(0, 2, size=(100, 1))
   1159 
   1160     with self.test_session() as sess:
   1161       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1162       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
   1163       auc, update_op = metrics.auc(labels, predictions)
   1164 
   1165       sess.run(variables.local_variables_initializer())
   1166       self.assertAlmostEqual(0, sess.run(update_op))
   1167 
   1168       self.assertAlmostEqual(0, auc.eval())
   1169 
   1170   def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
   1171     with self.test_session() as sess:
   1172       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
   1173       labels = array_ops.zeros([4])
   1174       auc, update_op = metrics.auc(labels, predictions)
   1175 
   1176       sess.run(variables.local_variables_initializer())
   1177       self.assertAlmostEqual(1, sess.run(update_op), 6)
   1178 
   1179       self.assertAlmostEqual(1, auc.eval(), 6)
   1180 
   1181   def testRecallOneAndPrecisionOne(self):
   1182     with self.test_session() as sess:
   1183       predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
   1184       labels = array_ops.ones([4])
   1185       auc, update_op = metrics.auc(labels, predictions, curve='PR')
   1186 
   1187       sess.run(variables.local_variables_initializer())
   1188       self.assertAlmostEqual(0.5, sess.run(update_op), 6)
   1189 
   1190       self.assertAlmostEqual(0.5, auc.eval(), 6)
   1191 
   1192   def np_auc(self, predictions, labels, weights):
   1193     """Computes the AUC explicitly using Numpy.
   1194 
   1195     Args:
   1196       predictions: an ndarray with shape [N].
   1197       labels: an ndarray with shape [N].
   1198       weights: an ndarray with shape [N].
   1199 
   1200     Returns:
   1201       the area under the ROC curve.
   1202     """
   1203     if weights is None:
   1204       weights = np.ones(np.size(predictions))
   1205     is_positive = labels > 0
   1206     num_positives = np.sum(weights[is_positive])
   1207     num_negatives = np.sum(weights[~is_positive])
   1208 
   1209     # Sort descending:
   1210     inds = np.argsort(-predictions)
   1211 
   1212     sorted_labels = labels[inds]
   1213     sorted_weights = weights[inds]
   1214     is_positive = sorted_labels > 0
   1215 
   1216     tp = np.cumsum(sorted_weights * is_positive) / num_positives
   1217     return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives
   1218 
   1219   def testWithMultipleUpdates(self):
   1220     num_samples = 1000
   1221     batch_size = 10
   1222     num_batches = int(num_samples / batch_size)
   1223 
   1224     # Create the labels and data.
   1225     labels = np.random.randint(0, 2, size=num_samples)
   1226     noise = np.random.normal(0.0, scale=0.2, size=num_samples)
   1227     predictions = 0.4 + 0.2 * labels + noise
   1228     predictions[predictions > 1] = 1
   1229     predictions[predictions < 0] = 0
   1230 
   1231     def _enqueue_as_batches(x, enqueue_ops):
   1232       x_batches = x.astype(np.float32).reshape((num_batches, batch_size))
   1233       x_queue = data_flow_ops.FIFOQueue(
   1234           num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
   1235       for i in range(num_batches):
   1236         enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :]))
   1237       return x_queue.dequeue()
   1238 
   1239     for weights in (None, np.ones(num_samples), np.random.exponential(
   1240         scale=1.0, size=num_samples)):
   1241       expected_auc = self.np_auc(predictions, labels, weights)
   1242 
   1243       with self.test_session() as sess:
   1244         enqueue_ops = [[] for i in range(num_batches)]
   1245         tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
   1246         tf_labels = _enqueue_as_batches(labels, enqueue_ops)
   1247         tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
   1248                       weights is not None else None)
   1249 
   1250         for i in range(num_batches):
   1251           sess.run(enqueue_ops[i])
   1252 
   1253         auc, update_op = metrics.auc(tf_labels,
   1254                                      tf_predictions,
   1255                                      curve='ROC',
   1256                                      num_thresholds=500,
   1257                                      weights=tf_weights)
   1258 
   1259         sess.run(variables.local_variables_initializer())
   1260         for i in range(num_batches):
   1261           sess.run(update_op)
   1262 
   1263         # Since this is only approximate, we can't expect a 6 digits match.
   1264         # Although with higher number of samples/thresholds we should see the
   1265         # accuracy improving
   1266         self.assertAlmostEqual(expected_auc, auc.eval(), 2)
   1267 
   1268 
   1269 class SpecificityAtSensitivityTest(test.TestCase):
   1270 
   1271   def setUp(self):
   1272     np.random.seed(1)
   1273     ops.reset_default_graph()
   1274 
   1275   def testVars(self):
   1276     metrics.specificity_at_sensitivity(
   1277         predictions=array_ops.ones((10, 1)),
   1278         labels=array_ops.ones((10, 1)),
   1279         sensitivity=0.7)
   1280     _assert_metric_variables(self,
   1281                              ('specificity_at_sensitivity/true_positives:0',
   1282                               'specificity_at_sensitivity/false_negatives:0',
   1283                               'specificity_at_sensitivity/false_positives:0',
   1284                               'specificity_at_sensitivity/true_negatives:0'))
   1285 
   1286   def testMetricsCollection(self):
   1287     my_collection_name = '__metrics__'
   1288     mean, _ = metrics.specificity_at_sensitivity(
   1289         predictions=array_ops.ones((10, 1)),
   1290         labels=array_ops.ones((10, 1)),
   1291         sensitivity=0.7,
   1292         metrics_collections=[my_collection_name])
   1293     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   1294 
   1295   def testUpdatesCollection(self):
   1296     my_collection_name = '__updates__'
   1297     _, update_op = metrics.specificity_at_sensitivity(
   1298         predictions=array_ops.ones((10, 1)),
   1299         labels=array_ops.ones((10, 1)),
   1300         sensitivity=0.7,
   1301         updates_collections=[my_collection_name])
   1302     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   1303 
   1304   def testValueTensorIsIdempotent(self):
   1305     predictions = random_ops.random_uniform(
   1306         (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
   1307     labels = random_ops.random_uniform(
   1308         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
   1309     specificity, update_op = metrics.specificity_at_sensitivity(
   1310         labels, predictions, sensitivity=0.7)
   1311 
   1312     with self.test_session() as sess:
   1313       sess.run(variables.local_variables_initializer())
   1314 
   1315       # Run several updates.
   1316       for _ in range(10):
   1317         sess.run(update_op)
   1318 
   1319       # Then verify idempotency.
   1320       initial_specificity = specificity.eval()
   1321       for _ in range(10):
   1322         self.assertAlmostEqual(initial_specificity, specificity.eval(), 5)
   1323 
   1324   def testAllCorrect(self):
   1325     inputs = np.random.randint(0, 2, size=(100, 1))
   1326 
   1327     predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1328     labels = constant_op.constant(inputs)
   1329     specificity, update_op = metrics.specificity_at_sensitivity(
   1330         labels, predictions, sensitivity=0.7)
   1331 
   1332     with self.test_session() as sess:
   1333       sess.run(variables.local_variables_initializer())
   1334       self.assertEqual(1, sess.run(update_op))
   1335       self.assertEqual(1, specificity.eval())
   1336 
   1337   def testSomeCorrectHighSensitivity(self):
   1338     predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9]
   1339     labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1340 
   1341     predictions = constant_op.constant(
   1342         predictions_values, dtype=dtypes_lib.float32)
   1343     labels = constant_op.constant(labels_values)
   1344     specificity, update_op = metrics.specificity_at_sensitivity(
   1345         labels, predictions, sensitivity=0.8)
   1346 
   1347     with self.test_session() as sess:
   1348       sess.run(variables.local_variables_initializer())
   1349       self.assertAlmostEqual(1.0, sess.run(update_op))
   1350       self.assertAlmostEqual(1.0, specificity.eval())
   1351 
   1352   def testSomeCorrectLowSensitivity(self):
   1353     predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
   1354     labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1355 
   1356     predictions = constant_op.constant(
   1357         predictions_values, dtype=dtypes_lib.float32)
   1358     labels = constant_op.constant(labels_values)
   1359     specificity, update_op = metrics.specificity_at_sensitivity(
   1360         labels, predictions, sensitivity=0.4)
   1361 
   1362     with self.test_session() as sess:
   1363       sess.run(variables.local_variables_initializer())
   1364 
   1365       self.assertAlmostEqual(0.6, sess.run(update_op))
   1366       self.assertAlmostEqual(0.6, specificity.eval())
   1367 
   1368   def testWeighted1d_multipleLabelDtypes(self):
   1369     for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
   1370       predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
   1371       labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1372       weights_values = [3]
   1373 
   1374       predictions = constant_op.constant(
   1375           predictions_values, dtype=dtypes_lib.float32)
   1376       labels = math_ops.cast(labels_values, dtype=label_dtype)
   1377       weights = constant_op.constant(weights_values)
   1378       specificity, update_op = metrics.specificity_at_sensitivity(
   1379           labels, predictions, weights=weights, sensitivity=0.4)
   1380 
   1381       with self.test_session() as sess:
   1382         sess.run(variables.local_variables_initializer())
   1383 
   1384         self.assertAlmostEqual(0.6, sess.run(update_op))
   1385         self.assertAlmostEqual(0.6, specificity.eval())
   1386 
   1387   def testWeighted2d(self):
   1388     predictions_values = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
   1389     labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1390     weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
   1391 
   1392     predictions = constant_op.constant(
   1393         predictions_values, dtype=dtypes_lib.float32)
   1394     labels = constant_op.constant(labels_values)
   1395     weights = constant_op.constant(weights_values)
   1396     specificity, update_op = metrics.specificity_at_sensitivity(
   1397         labels, predictions, weights=weights, sensitivity=0.4)
   1398 
   1399     with self.test_session() as sess:
   1400       sess.run(variables.local_variables_initializer())
   1401 
   1402       self.assertAlmostEqual(8.0 / 15.0, sess.run(update_op))
   1403       self.assertAlmostEqual(8.0 / 15.0, specificity.eval())
   1404 
   1405 
   1406 class SensitivityAtSpecificityTest(test.TestCase):
   1407 
   1408   def setUp(self):
   1409     np.random.seed(1)
   1410     ops.reset_default_graph()
   1411 
   1412   def testVars(self):
   1413     metrics.sensitivity_at_specificity(
   1414         predictions=array_ops.ones((10, 1)),
   1415         labels=array_ops.ones((10, 1)),
   1416         specificity=0.7)
   1417     _assert_metric_variables(self,
   1418                              ('sensitivity_at_specificity/true_positives:0',
   1419                               'sensitivity_at_specificity/false_negatives:0',
   1420                               'sensitivity_at_specificity/false_positives:0',
   1421                               'sensitivity_at_specificity/true_negatives:0'))
   1422 
   1423   def testMetricsCollection(self):
   1424     my_collection_name = '__metrics__'
   1425     mean, _ = metrics.sensitivity_at_specificity(
   1426         predictions=array_ops.ones((10, 1)),
   1427         labels=array_ops.ones((10, 1)),
   1428         specificity=0.7,
   1429         metrics_collections=[my_collection_name])
   1430     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   1431 
   1432   def testUpdatesCollection(self):
   1433     my_collection_name = '__updates__'
   1434     _, update_op = metrics.sensitivity_at_specificity(
   1435         predictions=array_ops.ones((10, 1)),
   1436         labels=array_ops.ones((10, 1)),
   1437         specificity=0.7,
   1438         updates_collections=[my_collection_name])
   1439     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   1440 
   1441   def testValueTensorIsIdempotent(self):
   1442     predictions = random_ops.random_uniform(
   1443         (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
   1444     labels = random_ops.random_uniform(
   1445         (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
   1446     sensitivity, update_op = metrics.sensitivity_at_specificity(
   1447         labels, predictions, specificity=0.7)
   1448 
   1449     with self.test_session() as sess:
   1450       sess.run(variables.local_variables_initializer())
   1451 
   1452       # Run several updates.
   1453       for _ in range(10):
   1454         sess.run(update_op)
   1455 
   1456       # Then verify idempotency.
   1457       initial_sensitivity = sensitivity.eval()
   1458       for _ in range(10):
   1459         self.assertAlmostEqual(initial_sensitivity, sensitivity.eval(), 5)
   1460 
   1461   def testAllCorrect(self):
   1462     inputs = np.random.randint(0, 2, size=(100, 1))
   1463 
   1464     predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1465     labels = constant_op.constant(inputs)
   1466     specificity, update_op = metrics.sensitivity_at_specificity(
   1467         labels, predictions, specificity=0.7)
   1468 
   1469     with self.test_session() as sess:
   1470       sess.run(variables.local_variables_initializer())
   1471       self.assertEqual(1, sess.run(update_op))
   1472       self.assertEqual(1, specificity.eval())
   1473 
   1474   def testSomeCorrectHighSpecificity(self):
   1475     predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
   1476     labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1477 
   1478     predictions = constant_op.constant(
   1479         predictions_values, dtype=dtypes_lib.float32)
   1480     labels = constant_op.constant(labels_values)
   1481     specificity, update_op = metrics.sensitivity_at_specificity(
   1482         labels, predictions, specificity=0.8)
   1483 
   1484     with self.test_session() as sess:
   1485       sess.run(variables.local_variables_initializer())
   1486       self.assertAlmostEqual(0.8, sess.run(update_op))
   1487       self.assertAlmostEqual(0.8, specificity.eval())
   1488 
   1489   def testSomeCorrectLowSpecificity(self):
   1490     predictions_values = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
   1491     labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1492 
   1493     predictions = constant_op.constant(
   1494         predictions_values, dtype=dtypes_lib.float32)
   1495     labels = constant_op.constant(labels_values)
   1496     specificity, update_op = metrics.sensitivity_at_specificity(
   1497         labels, predictions, specificity=0.4)
   1498 
   1499     with self.test_session() as sess:
   1500       sess.run(variables.local_variables_initializer())
   1501       self.assertAlmostEqual(0.6, sess.run(update_op))
   1502       self.assertAlmostEqual(0.6, specificity.eval())
   1503 
   1504   def testWeighted_multipleLabelDtypes(self):
   1505     for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
   1506       predictions_values = [
   1507           0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
   1508       labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
   1509       weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]
   1510 
   1511       predictions = constant_op.constant(
   1512           predictions_values, dtype=dtypes_lib.float32)
   1513       labels = math_ops.cast(labels_values, dtype=label_dtype)
   1514       weights = constant_op.constant(weights_values)
   1515       specificity, update_op = metrics.sensitivity_at_specificity(
   1516           labels, predictions, weights=weights, specificity=0.4)
   1517 
   1518       with self.test_session() as sess:
   1519         sess.run(variables.local_variables_initializer())
   1520         self.assertAlmostEqual(0.675, sess.run(update_op))
   1521         self.assertAlmostEqual(0.675, specificity.eval())
   1522 
   1523 
   1524 # TODO(nsilberman): Break this up into two sets of tests.
   1525 class PrecisionRecallThresholdsTest(test.TestCase):
   1526 
   1527   def setUp(self):
   1528     np.random.seed(1)
   1529     ops.reset_default_graph()
   1530 
   1531   def testVars(self):
   1532     metrics.precision_at_thresholds(
   1533         predictions=array_ops.ones((10, 1)),
   1534         labels=array_ops.ones((10, 1)),
   1535         thresholds=[0, 0.5, 1.0])
   1536     _assert_metric_variables(self, (
   1537         'precision_at_thresholds/true_positives:0',
   1538         'precision_at_thresholds/false_positives:0',
   1539     ))
   1540 
   1541   def testMetricsCollection(self):
   1542     my_collection_name = '__metrics__'
   1543     prec, _ = metrics.precision_at_thresholds(
   1544         predictions=array_ops.ones((10, 1)),
   1545         labels=array_ops.ones((10, 1)),
   1546         thresholds=[0, 0.5, 1.0],
   1547         metrics_collections=[my_collection_name])
   1548     rec, _ = metrics.recall_at_thresholds(
   1549         predictions=array_ops.ones((10, 1)),
   1550         labels=array_ops.ones((10, 1)),
   1551         thresholds=[0, 0.5, 1.0],
   1552         metrics_collections=[my_collection_name])
   1553     self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec])
   1554 
   1555   def testUpdatesCollection(self):
   1556     my_collection_name = '__updates__'
   1557     _, precision_op = metrics.precision_at_thresholds(
   1558         predictions=array_ops.ones((10, 1)),
   1559         labels=array_ops.ones((10, 1)),
   1560         thresholds=[0, 0.5, 1.0],
   1561         updates_collections=[my_collection_name])
   1562     _, recall_op = metrics.recall_at_thresholds(
   1563         predictions=array_ops.ones((10, 1)),
   1564         labels=array_ops.ones((10, 1)),
   1565         thresholds=[0, 0.5, 1.0],
   1566         updates_collections=[my_collection_name])
   1567     self.assertListEqual(
   1568         ops.get_collection(my_collection_name), [precision_op, recall_op])
   1569 
   1570   def testValueTensorIsIdempotent(self):
   1571     predictions = random_ops.random_uniform(
   1572         (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
   1573     labels = random_ops.random_uniform(
   1574         (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1)
   1575     thresholds = [0, 0.5, 1.0]
   1576     prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1577                                                     thresholds)
   1578     rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds)
   1579 
   1580     with self.test_session() as sess:
   1581       sess.run(variables.local_variables_initializer())
   1582 
   1583       # Run several updates, then verify idempotency.
   1584       sess.run([prec_op, rec_op])
   1585       initial_prec = prec.eval()
   1586       initial_rec = rec.eval()
   1587       for _ in range(10):
   1588         sess.run([prec_op, rec_op])
   1589         self.assertAllClose(initial_prec, prec.eval())
   1590         self.assertAllClose(initial_rec, rec.eval())
   1591 
   1592   # TODO(nsilberman): fix tests (passing but incorrect).
   1593   def testAllCorrect(self):
   1594     inputs = np.random.randint(0, 2, size=(100, 1))
   1595 
   1596     with self.test_session() as sess:
   1597       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1598       labels = constant_op.constant(inputs)
   1599       thresholds = [0.5]
   1600       prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1601                                                       thresholds)
   1602       rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
   1603                                                  thresholds)
   1604 
   1605       sess.run(variables.local_variables_initializer())
   1606       sess.run([prec_op, rec_op])
   1607 
   1608       self.assertEqual(1, prec.eval())
   1609       self.assertEqual(1, rec.eval())
   1610 
   1611   def testSomeCorrect_multipleLabelDtypes(self):
   1612     with self.test_session() as sess:
   1613       for label_dtype in (
   1614           dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
   1615         predictions = constant_op.constant(
   1616             [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
   1617         labels = math_ops.cast(
   1618             constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
   1619         thresholds = [0.5]
   1620         prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1621                                                         thresholds)
   1622         rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
   1623                                                    thresholds)
   1624 
   1625         sess.run(variables.local_variables_initializer())
   1626         sess.run([prec_op, rec_op])
   1627 
   1628         self.assertAlmostEqual(0.5, prec.eval())
   1629         self.assertAlmostEqual(0.5, rec.eval())
   1630 
   1631   def testAllIncorrect(self):
   1632     inputs = np.random.randint(0, 2, size=(100, 1))
   1633 
   1634     with self.test_session() as sess:
   1635       predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
   1636       labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
   1637       thresholds = [0.5]
   1638       prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1639                                                       thresholds)
   1640       rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
   1641                                                  thresholds)
   1642 
   1643       sess.run(variables.local_variables_initializer())
   1644       sess.run([prec_op, rec_op])
   1645 
   1646       self.assertAlmostEqual(0, prec.eval())
   1647       self.assertAlmostEqual(0, rec.eval())
   1648 
   1649   def testWeights1d(self):
   1650     with self.test_session() as sess:
   1651       predictions = constant_op.constant(
   1652           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
   1653       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
   1654       weights = constant_op.constant(
   1655           [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32)
   1656       thresholds = [0.5, 1.1]
   1657       prec, prec_op = metrics.precision_at_thresholds(
   1658           labels, predictions, thresholds, weights=weights)
   1659       rec, rec_op = metrics.recall_at_thresholds(
   1660           labels, predictions, thresholds, weights=weights)
   1661 
   1662       [prec_low, prec_high] = array_ops.split(
   1663           value=prec, num_or_size_splits=2, axis=0)
   1664       prec_low = array_ops.reshape(prec_low, shape=())
   1665       prec_high = array_ops.reshape(prec_high, shape=())
   1666       [rec_low, rec_high] = array_ops.split(
   1667           value=rec, num_or_size_splits=2, axis=0)
   1668       rec_low = array_ops.reshape(rec_low, shape=())
   1669       rec_high = array_ops.reshape(rec_high, shape=())
   1670 
   1671       sess.run(variables.local_variables_initializer())
   1672       sess.run([prec_op, rec_op])
   1673 
   1674       self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
   1675       self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
   1676       self.assertAlmostEqual(1.0, rec_low.eval(), places=5)
   1677       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
   1678 
   1679   def testWeights2d(self):
   1680     with self.test_session() as sess:
   1681       predictions = constant_op.constant(
   1682           [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32)
   1683       labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2))
   1684       weights = constant_op.constant(
   1685           [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32)
   1686       thresholds = [0.5, 1.1]
   1687       prec, prec_op = metrics.precision_at_thresholds(
   1688           labels, predictions, thresholds, weights=weights)
   1689       rec, rec_op = metrics.recall_at_thresholds(
   1690           labels, predictions, thresholds, weights=weights)
   1691 
   1692       [prec_low, prec_high] = array_ops.split(
   1693           value=prec, num_or_size_splits=2, axis=0)
   1694       prec_low = array_ops.reshape(prec_low, shape=())
   1695       prec_high = array_ops.reshape(prec_high, shape=())
   1696       [rec_low, rec_high] = array_ops.split(
   1697           value=rec, num_or_size_splits=2, axis=0)
   1698       rec_low = array_ops.reshape(rec_low, shape=())
   1699       rec_high = array_ops.reshape(rec_high, shape=())
   1700 
   1701       sess.run(variables.local_variables_initializer())
   1702       sess.run([prec_op, rec_op])
   1703 
   1704       self.assertAlmostEqual(1.0, prec_low.eval(), places=5)
   1705       self.assertAlmostEqual(0.0, prec_high.eval(), places=5)
   1706       self.assertAlmostEqual(1.0, rec_low.eval(), places=5)
   1707       self.assertAlmostEqual(0.0, rec_high.eval(), places=5)
   1708 
   1709   def testExtremeThresholds(self):
   1710     with self.test_session() as sess:
   1711       predictions = constant_op.constant(
   1712           [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
   1713       labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4))
   1714       thresholds = [-1.0, 2.0]  # lower/higher than any values
   1715       prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1716                                                       thresholds)
   1717       rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
   1718                                                  thresholds)
   1719 
   1720       [prec_low, prec_high] = array_ops.split(
   1721           value=prec, num_or_size_splits=2, axis=0)
   1722       [rec_low, rec_high] = array_ops.split(
   1723           value=rec, num_or_size_splits=2, axis=0)
   1724 
   1725       sess.run(variables.local_variables_initializer())
   1726       sess.run([prec_op, rec_op])
   1727 
   1728       self.assertAlmostEqual(0.75, prec_low.eval())
   1729       self.assertAlmostEqual(0.0, prec_high.eval())
   1730       self.assertAlmostEqual(1.0, rec_low.eval())
   1731       self.assertAlmostEqual(0.0, rec_high.eval())
   1732 
   1733   def testZeroLabelsPredictions(self):
   1734     with self.test_session() as sess:
   1735       predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
   1736       labels = array_ops.zeros([4])
   1737       thresholds = [0.5]
   1738       prec, prec_op = metrics.precision_at_thresholds(labels, predictions,
   1739                                                       thresholds)
   1740       rec, rec_op = metrics.recall_at_thresholds(labels, predictions,
   1741                                                  thresholds)
   1742 
   1743       sess.run(variables.local_variables_initializer())
   1744       sess.run([prec_op, rec_op])
   1745 
   1746       self.assertAlmostEqual(0, prec.eval(), 6)
   1747       self.assertAlmostEqual(0, rec.eval(), 6)
   1748 
   1749   def testWithMultipleUpdates(self):
   1750     num_samples = 1000
   1751     batch_size = 10
   1752     num_batches = int(num_samples / batch_size)
   1753 
   1754     # Create the labels and data.
   1755     labels = np.random.randint(0, 2, size=(num_samples, 1))
   1756     noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
   1757     predictions = 0.4 + 0.2 * labels + noise
   1758     predictions[predictions > 1] = 1
   1759     predictions[predictions < 0] = 0
   1760     thresholds = [0.3]
   1761 
   1762     tp = 0
   1763     fp = 0
   1764     fn = 0
   1765     tn = 0
   1766     for i in range(num_samples):
   1767       if predictions[i] > thresholds[0]:
   1768         if labels[i] == 1:
   1769           tp += 1
   1770         else:
   1771           fp += 1
   1772       else:
   1773         if labels[i] == 1:
   1774           fn += 1
   1775         else:
   1776           tn += 1
   1777     epsilon = 1e-7
   1778     expected_prec = tp / (epsilon + tp + fp)
   1779     expected_rec = tp / (epsilon + tp + fn)
   1780 
   1781     labels = labels.astype(np.float32)
   1782     predictions = predictions.astype(np.float32)
   1783 
   1784     with self.test_session() as sess:
   1785       # Reshape the data so its easy to queue up:
   1786       predictions_batches = predictions.reshape((batch_size, num_batches))
   1787       labels_batches = labels.reshape((batch_size, num_batches))
   1788 
   1789       # Enqueue the data:
   1790       predictions_queue = data_flow_ops.FIFOQueue(
   1791           num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
   1792       labels_queue = data_flow_ops.FIFOQueue(
   1793           num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
   1794 
   1795       for i in range(int(num_batches)):
   1796         tf_prediction = constant_op.constant(predictions_batches[:, i])
   1797         tf_label = constant_op.constant(labels_batches[:, i])
   1798         sess.run([
   1799             predictions_queue.enqueue(tf_prediction),
   1800             labels_queue.enqueue(tf_label)
   1801         ])
   1802 
   1803       tf_predictions = predictions_queue.dequeue()
   1804       tf_labels = labels_queue.dequeue()
   1805 
   1806       prec, prec_op = metrics.precision_at_thresholds(tf_labels, tf_predictions,
   1807                                                       thresholds)
   1808       rec, rec_op = metrics.recall_at_thresholds(tf_labels, tf_predictions,
   1809                                                  thresholds)
   1810 
   1811       sess.run(variables.local_variables_initializer())
   1812       for _ in range(int(num_samples / batch_size)):
   1813         sess.run([prec_op, rec_op])
   1814       # Since this is only approximate, we can't expect a 6 digits match.
   1815       # Although with higher number of samples/thresholds we should see the
   1816       # accuracy improving
   1817       self.assertAlmostEqual(expected_prec, prec.eval(), 2)
   1818       self.assertAlmostEqual(expected_rec, rec.eval(), 2)
   1819 
   1820 
   1821 def _test_precision_at_k(predictions,
   1822                          labels,
   1823                          k,
   1824                          expected,
   1825                          class_id=None,
   1826                          weights=None,
   1827                          test_case=None):
   1828   with ops.Graph().as_default() as g, test_case.test_session(g):
   1829     if weights is not None:
   1830       weights = constant_op.constant(weights, dtypes_lib.float32)
   1831     metric, update = metrics.precision_at_k(
   1832         predictions=constant_op.constant(predictions, dtypes_lib.float32),
   1833         labels=labels,
   1834         k=k,
   1835         class_id=class_id,
   1836         weights=weights)
   1837 
   1838     # Fails without initialized vars.
   1839     test_case.assertRaises(errors_impl.OpError, metric.eval)
   1840     test_case.assertRaises(errors_impl.OpError, update.eval)
   1841     variables.variables_initializer(variables.local_variables()).run()
   1842 
   1843     # Run per-step op and assert expected values.
   1844     if math.isnan(expected):
   1845       _assert_nan(test_case, update.eval())
   1846       _assert_nan(test_case, metric.eval())
   1847     else:
   1848       test_case.assertEqual(expected, update.eval())
   1849       test_case.assertEqual(expected, metric.eval())
   1850 
   1851 
   1852 def _test_precision_at_top_k(
   1853     predictions_idx,
   1854     labels,
   1855     expected,
   1856     k=None,
   1857     class_id=None,
   1858     weights=None,
   1859     test_case=None):
   1860   with ops.Graph().as_default() as g, test_case.test_session(g):
   1861     if weights is not None:
   1862       weights = constant_op.constant(weights, dtypes_lib.float32)
   1863     metric, update = metrics.precision_at_top_k(
   1864         predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
   1865         labels=labels,
   1866         k=k,
   1867         class_id=class_id,
   1868         weights=weights)
   1869 
   1870     # Fails without initialized vars.
   1871     test_case.assertRaises(errors_impl.OpError, metric.eval)
   1872     test_case.assertRaises(errors_impl.OpError, update.eval)
   1873     variables.variables_initializer(variables.local_variables()).run()
   1874 
   1875     # Run per-step op and assert expected values.
   1876     if math.isnan(expected):
   1877       test_case.assertTrue(math.isnan(update.eval()))
   1878       test_case.assertTrue(math.isnan(metric.eval()))
   1879     else:
   1880       test_case.assertEqual(expected, update.eval())
   1881       test_case.assertEqual(expected, metric.eval())
   1882 
   1883 
   1884 def _test_average_precision_at_k(predictions,
   1885                                  labels,
   1886                                  k,
   1887                                  expected,
   1888                                  weights=None,
   1889                                  test_case=None):
   1890   with ops.Graph().as_default() as g, test_case.test_session(g):
   1891     if weights is not None:
   1892       weights = constant_op.constant(weights, dtypes_lib.float32)
   1893     predictions = constant_op.constant(predictions, dtypes_lib.float32)
   1894     metric, update = metrics.average_precision_at_k(
   1895         labels, predictions, k, weights=weights)
   1896 
   1897     # Fails without initialized vars.
   1898     test_case.assertRaises(errors_impl.OpError, metric.eval)
   1899     test_case.assertRaises(errors_impl.OpError, update.eval)
   1900     variables.variables_initializer(variables.local_variables()).run()
   1901 
   1902     # Run per-step op and assert expected values.
   1903     if math.isnan(expected):
   1904       _assert_nan(test_case, update.eval())
   1905       _assert_nan(test_case, metric.eval())
   1906     else:
   1907       test_case.assertAlmostEqual(expected, update.eval())
   1908       test_case.assertAlmostEqual(expected, metric.eval())
   1909 
   1910 
   1911 class SingleLabelPrecisionAtKTest(test.TestCase):
   1912 
   1913   def setUp(self):
   1914     self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
   1915     self._predictions_idx = [[3], [3]]
   1916     indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
   1917     class_labels = (3, 2)
   1918     # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
   1919     self._labels = (
   1920         _binary_2d_label_to_1d_sparse_value(indicator_labels),
   1921         _binary_2d_label_to_2d_sparse_value(indicator_labels), np.array(
   1922             class_labels, dtype=np.int64), np.array(
   1923                 [[class_id] for class_id in class_labels], dtype=np.int64))
   1924     self._test_precision_at_k = functools.partial(
   1925         _test_precision_at_k, test_case=self)
   1926     self._test_precision_at_top_k = functools.partial(
   1927         _test_precision_at_top_k, test_case=self)
   1928     self._test_average_precision_at_k = functools.partial(
   1929         _test_average_precision_at_k, test_case=self)
   1930 
   1931   def test_at_k1_nan(self):
   1932     for labels in self._labels:
   1933       # Classes 0,1,2 have 0 predictions, classes -1 and 4 are out of range.
   1934       for class_id in (-1, 0, 1, 2, 4):
   1935         self._test_precision_at_k(
   1936             self._predictions, labels, k=1, expected=NAN, class_id=class_id)
   1937         self._test_precision_at_top_k(
   1938             self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)
   1939 
   1940   def test_at_k1(self):
   1941     for labels in self._labels:
   1942       # Class 3: 1 label, 2 predictions, 1 correct.
   1943       self._test_precision_at_k(
   1944           self._predictions, labels, k=1, expected=1.0 / 2, class_id=3)
   1945       self._test_precision_at_top_k(
   1946           self._predictions_idx, labels, k=1, expected=1.0 / 2, class_id=3)
   1947 
   1948       # All classes: 2 labels, 2 predictions, 1 correct.
   1949       self._test_precision_at_k(
   1950           self._predictions, labels, k=1, expected=1.0 / 2)
   1951       self._test_precision_at_top_k(
   1952           self._predictions_idx, labels, k=1, expected=1.0 / 2)
   1953       self._test_average_precision_at_k(
   1954           self._predictions, labels, k=1, expected=1.0 / 2)
   1955 
   1956 
   1957 class MultiLabelPrecisionAtKTest(test.TestCase):
   1958 
   1959   def setUp(self):
   1960     self._test_precision_at_k = functools.partial(
   1961         _test_precision_at_k, test_case=self)
   1962     self._test_precision_at_top_k = functools.partial(
   1963         _test_precision_at_top_k, test_case=self)
   1964     self._test_average_precision_at_k = functools.partial(
   1965         _test_average_precision_at_k, test_case=self)
   1966 
   1967   def test_average_precision(self):
   1968     # Example 1.
   1969     # Matches example here:
   1970     # fastml.com/what-you-wanted-to-know-about-mean-average-precision
   1971     labels_ex1 = (0, 1, 2, 3, 4)
   1972     labels = np.array([labels_ex1], dtype=np.int64)
   1973     predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
   1974     predictions = (predictions_ex1,)
   1975     predictions_idx_ex1 = (5, 3, 6, 0, 1)
   1976     precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
   1977     avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
   1978                          (precision_ex1[1] + precision_ex1[3]) / 4)
   1979     for i in xrange(4):
   1980       k = i + 1
   1981       self._test_precision_at_k(
   1982           predictions, labels, k, expected=precision_ex1[i])
   1983       self._test_precision_at_top_k(
   1984           (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
   1985       self._test_average_precision_at_k(
   1986           predictions, labels, k, expected=avg_precision_ex1[i])
   1987 
   1988     # Example 2.
   1989     labels_ex2 = (0, 2, 4, 5, 6)
   1990     labels = np.array([labels_ex2], dtype=np.int64)
   1991     predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2)
   1992     predictions = (predictions_ex2,)
   1993     predictions_idx_ex2 = (1, 3, 0, 6, 5)
   1994     precision_ex2 = (0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4)
   1995     avg_precision_ex2 = (0.0 / 1, 0.0 / 2, precision_ex2[2] / 3,
   1996                          (precision_ex2[2] + precision_ex2[3]) / 4)
   1997     for i in xrange(4):
   1998       k = i + 1
   1999       self._test_precision_at_k(
   2000           predictions, labels, k, expected=precision_ex2[i])
   2001       self._test_precision_at_top_k(
   2002           (predictions_idx_ex2[:k],), labels, k=k, expected=precision_ex2[i])
   2003       self._test_average_precision_at_k(
   2004           predictions, labels, k, expected=avg_precision_ex2[i])
   2005 
   2006     # Both examples, we expect both precision and average precision to be the
   2007     # average of the 2 examples.
   2008     labels = np.array([labels_ex1, labels_ex2], dtype=np.int64)
   2009     predictions = (predictions_ex1, predictions_ex2)
   2010     streaming_precision = [(ex1 + ex2) / 2
   2011                            for ex1, ex2 in zip(precision_ex1, precision_ex2)]
   2012     streaming_average_precision = [
   2013         (ex1 + ex2) / 2
   2014         for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
   2015     ]
   2016     for i in xrange(4):
   2017       k = i + 1
   2018       predictions_idx = (predictions_idx_ex1[:k], predictions_idx_ex2[:k])
   2019       self._test_precision_at_k(
   2020           predictions, labels, k, expected=streaming_precision[i])
   2021       self._test_precision_at_top_k(
   2022           predictions_idx, labels, k=k, expected=streaming_precision[i])
   2023       self._test_average_precision_at_k(
   2024           predictions, labels, k, expected=streaming_average_precision[i])
   2025 
   2026     # Weighted examples, we expect streaming average precision to be the
   2027     # weighted average of the 2 examples.
   2028     weights = (0.3, 0.6)
   2029     streaming_average_precision = [
   2030         (weights[0] * ex1 + weights[1] * ex2) / (weights[0] + weights[1])
   2031         for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2)
   2032     ]
   2033     for i in xrange(4):
   2034       k = i + 1
   2035       self._test_average_precision_at_k(
   2036           predictions,
   2037           labels,
   2038           k,
   2039           expected=streaming_average_precision[i],
   2040           weights=weights)
   2041 
   2042   def test_average_precision_some_labels_out_of_range(self):
   2043     """Tests that labels outside the [0, n_classes) range are ignored."""
   2044     labels_ex1 = (-1, 0, 1, 2, 3, 4, 7)
   2045     labels = np.array([labels_ex1], dtype=np.int64)
   2046     predictions_ex1 = (0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3)
   2047     predictions = (predictions_ex1,)
   2048     predictions_idx_ex1 = (5, 3, 6, 0, 1)
   2049     precision_ex1 = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
   2050     avg_precision_ex1 = (0.0 / 1, precision_ex1[1] / 2, precision_ex1[1] / 3,
   2051                          (precision_ex1[1] + precision_ex1[3]) / 4)
   2052     for i in xrange(4):
   2053       k = i + 1
   2054       self._test_precision_at_k(
   2055           predictions, labels, k, expected=precision_ex1[i])
   2056       self._test_precision_at_top_k(
   2057           (predictions_idx_ex1[:k],), labels, k=k, expected=precision_ex1[i])
   2058       self._test_average_precision_at_k(
   2059           predictions, labels, k, expected=avg_precision_ex1[i])
   2060 
   2061   def test_three_labels_at_k5_no_predictions(self):
   2062     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2063                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
   2064     predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
   2065     sparse_labels = _binary_2d_label_to_2d_sparse_value(
   2066         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
   2067     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
   2068 
   2069     for labels in (sparse_labels, dense_labels):
   2070       # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
   2071       for class_id in (-1, 1, 3, 8, 10):
   2072         self._test_precision_at_k(
   2073             predictions, labels, k=5, expected=NAN, class_id=class_id)
   2074         self._test_precision_at_top_k(
   2075             predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
   2076 
   2077   def test_three_labels_at_k5_no_labels(self):
   2078     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2079                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
   2080     predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
   2081     sparse_labels = _binary_2d_label_to_2d_sparse_value(
   2082         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
   2083     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
   2084 
   2085     for labels in (sparse_labels, dense_labels):
   2086       # Classes 0,4,6,9: 0 labels, >=1 prediction.
   2087       for class_id in (0, 4, 6, 9):
   2088         self._test_precision_at_k(
   2089             predictions, labels, k=5, expected=0.0, class_id=class_id)
   2090         self._test_precision_at_top_k(
   2091             predictions_idx, labels, k=5, expected=0.0, class_id=class_id)
   2092 
   2093   def test_three_labels_at_k5(self):
   2094     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2095                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
   2096     predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
   2097     sparse_labels = _binary_2d_label_to_2d_sparse_value(
   2098         [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]])
   2099     dense_labels = np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64)
   2100 
   2101     for labels in (sparse_labels, dense_labels):
   2102       # Class 2: 2 labels, 2 correct predictions.
   2103       self._test_precision_at_k(
   2104           predictions, labels, k=5, expected=2.0 / 2, class_id=2)
   2105       self._test_precision_at_top_k(
   2106           predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
   2107 
   2108       # Class 5: 1 label, 1 correct prediction.
   2109       self._test_precision_at_k(
   2110           predictions, labels, k=5, expected=1.0 / 1, class_id=5)
   2111       self._test_precision_at_top_k(
   2112           predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
   2113 
   2114       # Class 7: 1 label, 1 incorrect prediction.
   2115       self._test_precision_at_k(
   2116           predictions, labels, k=5, expected=0.0 / 1, class_id=7)
   2117       self._test_precision_at_top_k(
   2118           predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
   2119 
   2120       # All classes: 10 predictions, 3 correct.
   2121       self._test_precision_at_k(
   2122           predictions, labels, k=5, expected=3.0 / 10)
   2123       self._test_precision_at_top_k(
   2124           predictions_idx, labels, k=5, expected=3.0 / 10)
   2125 
   2126   def test_three_labels_at_k5_some_out_of_range(self):
   2127     """Tests that labels outside the [0, n_classes) range are ignored."""
   2128     predictions = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2129                    [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
   2130     predictions_idx = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
   2131     sp_labels = sparse_tensor.SparseTensorValue(
   2132         indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
   2133                  [1, 3]],
   2134         # values -1 and 10 are outside the [0, n_classes) range and are ignored.
   2135         values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
   2136         dense_shape=[2, 4])
   2137 
   2138     # Class 2: 2 labels, 2 correct predictions.
   2139     self._test_precision_at_k(
   2140         predictions, sp_labels, k=5, expected=2.0 / 2, class_id=2)
   2141     self._test_precision_at_top_k(
   2142         predictions_idx, sp_labels, k=5, expected=2.0 / 2, class_id=2)
   2143 
   2144     # Class 5: 1 label, 1 correct prediction.
   2145     self._test_precision_at_k(
   2146         predictions, sp_labels, k=5, expected=1.0 / 1, class_id=5)
   2147     self._test_precision_at_top_k(
   2148         predictions_idx, sp_labels, k=5, expected=1.0 / 1, class_id=5)
   2149 
   2150     # Class 7: 1 label, 1 incorrect prediction.
   2151     self._test_precision_at_k(
   2152         predictions, sp_labels, k=5, expected=0.0 / 1, class_id=7)
   2153     self._test_precision_at_top_k(
   2154         predictions_idx, sp_labels, k=5, expected=0.0 / 1, class_id=7)
   2155 
   2156     # All classes: 10 predictions, 3 correct.
   2157     self._test_precision_at_k(
   2158         predictions, sp_labels, k=5, expected=3.0 / 10)
   2159     self._test_precision_at_top_k(
   2160         predictions_idx, sp_labels, k=5, expected=3.0 / 10)
   2161 
   2162   def test_3d_nan(self):
   2163     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2164                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
   2165                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
   2166                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
   2167     predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
   2168                        [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
   2169     labels = _binary_3d_label_to_sparse_value(
   2170         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
   2171          [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
   2172 
   2173     # Classes 1,3,8 have 0 predictions, classes -1 and 10 are out of range.
   2174     for class_id in (-1, 1, 3, 8, 10):
   2175       self._test_precision_at_k(
   2176           predictions, labels, k=5, expected=NAN, class_id=class_id)
   2177       self._test_precision_at_top_k(
   2178           predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
   2179 
   2180   def test_3d_no_labels(self):
   2181     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2182                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
   2183                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
   2184                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
   2185     predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
   2186                        [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
   2187     labels = _binary_3d_label_to_sparse_value(
   2188         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
   2189          [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
   2190 
   2191     # Classes 0,4,6,9: 0 labels, >=1 prediction.
   2192     for class_id in (0, 4, 6, 9):
   2193       self._test_precision_at_k(
   2194           predictions, labels, k=5, expected=0.0, class_id=class_id)
   2195       self._test_precision_at_top_k(
   2196           predictions_idx, labels, k=5, expected=0.0, class_id=class_id)
   2197 
   2198   def test_3d(self):
   2199     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2200                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
   2201                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
   2202                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
   2203     predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
   2204                        [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
   2205     labels = _binary_3d_label_to_sparse_value(
   2206         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
   2207          [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
   2208 
   2209     # Class 2: 4 predictions, all correct.
   2210     self._test_precision_at_k(
   2211         predictions, labels, k=5, expected=4.0 / 4, class_id=2)
   2212     self._test_precision_at_top_k(
   2213         predictions_idx, labels, k=5, expected=4.0 / 4, class_id=2)
   2214 
   2215     # Class 5: 2 predictions, both correct.
   2216     self._test_precision_at_k(
   2217         predictions, labels, k=5, expected=2.0 / 2, class_id=5)
   2218     self._test_precision_at_top_k(
   2219         predictions_idx, labels, k=5, expected=2.0 / 2, class_id=5)
   2220 
   2221     # Class 7: 2 predictions, 1 correct.
   2222     self._test_precision_at_k(
   2223         predictions, labels, k=5, expected=1.0 / 2, class_id=7)
   2224     self._test_precision_at_top_k(
   2225         predictions_idx, labels, k=5, expected=1.0 / 2, class_id=7)
   2226 
   2227     # All classes: 20 predictions, 7 correct.
   2228     self._test_precision_at_k(
   2229         predictions, labels, k=5, expected=7.0 / 20)
   2230     self._test_precision_at_top_k(
   2231         predictions_idx, labels, k=5, expected=7.0 / 20)
   2232 
   2233   def test_3d_ignore_some(self):
   2234     predictions = [[[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
   2235                     [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]],
   2236                    [[0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6],
   2237                     [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]]]
   2238     predictions_idx = [[[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]],
   2239                        [[5, 7, 2, 9, 6], [9, 4, 6, 2, 0]]]
   2240     labels = _binary_3d_label_to_sparse_value(
   2241         [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
   2242          [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])
   2243 
   2244     # Class 2: 2 predictions, both correct.
   2245     self._test_precision_at_k(
   2246         predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
   2247         weights=[[1], [0]])
   2248     self._test_precision_at_top_k(
   2249         predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
   2250         weights=[[1], [0]])
   2251 
   2252     # Class 2: 2 predictions, both correct.
   2253     self._test_precision_at_k(
   2254         predictions, labels, k=5, expected=2.0 / 2.0, class_id=2,
   2255         weights=[[0], [1]])
   2256     self._test_precision_at_top_k(
   2257         predictions_idx, labels, k=5, expected=2.0 / 2.0, class_id=2,
   2258         weights=[[0], [1]])
   2259 
   2260     # Class 7: 1 incorrect prediction.
   2261     self._test_precision_at_k(
   2262         predictions, labels, k=5, expected=0.0 / 1.0, class_id=7,
   2263         weights=[[1], [0]])
   2264     self._test_precision_at_top_k(
   2265         predictions_idx, labels, k=5, expected=0.0 / 1.0, class_id=7,
   2266         weights=[[1], [0]])
   2267 
   2268     # Class 7: 1 correct prediction.
   2269     self._test_precision_at_k(
   2270         predictions, labels, k=5, expected=1.0 / 1.0, class_id=7,
   2271         weights=[[0], [1]])
   2272     self._test_precision_at_top_k(
   2273         predictions_idx, labels, k=5, expected=1.0 / 1.0, class_id=7,
   2274         weights=[[0], [1]])
   2275 
   2276     # Class 7: no predictions.
   2277     self._test_precision_at_k(
   2278         predictions, labels, k=5, expected=NAN, class_id=7,
   2279         weights=[[1, 0], [0, 1]])
   2280     self._test_precision_at_top_k(
   2281         predictions_idx, labels, k=5, expected=NAN, class_id=7,
   2282         weights=[[1, 0], [0, 1]])
   2283 
   2284     # Class 7: 2 predictions, 1 correct.
   2285     self._test_precision_at_k(
   2286         predictions, labels, k=5, expected=1.0 / 2.0, class_id=7,
   2287         weights=[[0, 1], [1, 0]])
   2288     self._test_precision_at_top_k(
   2289         predictions_idx, labels, k=5, expected=1.0 / 2.0, class_id=7,
   2290         weights=[[0, 1], [1, 0]])
   2291 
   2292 
   2293 def _test_recall_at_k(predictions,
   2294                       labels,
   2295                       k,
   2296                       expected,
   2297                       class_id=None,
   2298                       weights=None,
   2299                       test_case=None):
   2300   with ops.Graph().as_default() as g, test_case.test_session(g):
   2301     if weights is not None:
   2302       weights = constant_op.constant(weights, dtypes_lib.float32)
   2303     metric, update = metrics.recall_at_k(
   2304         predictions=constant_op.constant(predictions, dtypes_lib.float32),
   2305         labels=labels,
   2306         k=k,
   2307         class_id=class_id,
   2308         weights=weights)
   2309 
   2310     # Fails without initialized vars.
   2311     test_case.assertRaises(errors_impl.OpError, metric.eval)
   2312     test_case.assertRaises(errors_impl.OpError, update.eval)
   2313     variables.variables_initializer(variables.local_variables()).run()
   2314 
   2315     # Run per-step op and assert expected values.
   2316     if math.isnan(expected):
   2317       _assert_nan(test_case, update.eval())
   2318       _assert_nan(test_case, metric.eval())
   2319     else:
   2320       test_case.assertEqual(expected, update.eval())
   2321       test_case.assertEqual(expected, metric.eval())
   2322 
   2323 
   2324 def _test_recall_at_top_k(
   2325     predictions_idx,
   2326     labels,
   2327     expected,
   2328     k=None,
   2329     class_id=None,
   2330     weights=None,
   2331     test_case=None):
   2332   with ops.Graph().as_default() as g, test_case.test_session(g):
   2333     if weights is not None:
   2334       weights = constant_op.constant(weights, dtypes_lib.float32)
   2335     metric, update = metrics.recall_at_top_k(
   2336         predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
   2337         labels=labels,
   2338         k=k,
   2339         class_id=class_id,
   2340         weights=weights)
   2341 
   2342     # Fails without initialized vars.
   2343     test_case.assertRaises(errors_impl.OpError, metric.eval)
   2344     test_case.assertRaises(errors_impl.OpError, update.eval)
   2345     variables.variables_initializer(variables.local_variables()).run()
   2346 
   2347     # Run per-step op and assert expected values.
   2348     if math.isnan(expected):
   2349       _assert_nan(test_case, update.eval())
   2350       _assert_nan(test_case, metric.eval())
   2351     else:
   2352       test_case.assertEqual(expected, update.eval())
   2353       test_case.assertEqual(expected, metric.eval())
   2354 
   2355 
   2356 class SingleLabelRecallAtKTest(test.TestCase):
   2357 
   2358   def setUp(self):
   2359     self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
   2360     self._predictions_idx = [[3], [3]]
   2361     indicator_labels = ((0, 0, 0, 1), (0, 0, 1, 0))
   2362     class_labels = (3, 2)
   2363     # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
   2364     self._labels = (
   2365         _binary_2d_label_to_1d_sparse_value(indicator_labels),
   2366         _binary_2d_label_to_2d_sparse_value(indicator_labels), np.array(
   2367             class_labels, dtype=np.int64), np.array(
   2368                 [[class_id] for class_id in class_labels], dtype=np.int64))
   2369     self._test_recall_at_k = functools.partial(
   2370         _test_recall_at_k, test_case=self)
   2371     self._test_recall_at_top_k = functools.partial(
   2372         _test_recall_at_top_k, test_case=self)
   2373 
   2374   def test_at_k1_nan(self):
   2375     # Classes 0,1 have 0 labels, 0 predictions, classes -1 and 4 are out of
   2376     # range.
   2377     for labels in self._labels:
   2378       for class_id in (-1, 0, 1, 4):
   2379         self._test_recall_at_k(
   2380             self._predictions, labels, k=1, expected=NAN, class_id=class_id)
   2381         self._test_recall_at_top_k(
   2382             self._predictions_idx, labels, k=1, expected=NAN, class_id=class_id)
   2383 
   2384   def test_at_k1_no_predictions(self):
   2385     for labels in self._labels:
   2386       # Class 2: 0 predictions.
   2387       self._test_recall_at_k(
   2388           self._predictions, labels, k=1, expected=0.0, class_id=2)
   2389       self._test_recall_at_top_k(
   2390           self._predictions_idx, labels, k=1, expected=0.0, class_id=2)
   2391 
   2392   def test_one_label_at_k1(self):
   2393     for labels in self._labels:
   2394       # Class 3: 1 label, 2 predictions, 1 correct.
   2395       self._test_recall_at_k(
   2396           self._predictions, labels, k=1, expected=1.0 / 1, class_id=3)
   2397       self._test_recall_at_top_k(
   2398           self._predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3)
   2399 
   2400       # All classes: 2 labels, 2 predictions, 1 correct.
   2401       self._test_recall_at_k(self._predictions, labels, k=1, expected=1.0 / 2)
   2402       self._test_recall_at_top_k(
   2403           self._predictions_idx, labels, k=1, expected=1.0 / 2)
   2404 
   2405   def test_one_label_at_k1_weighted_class_id3(self):
   2406     predictions = self._predictions
   2407     predictions_idx = self._predictions_idx
   2408     for labels in self._labels:
   2409       # Class 3: 1 label, 2 predictions, 1 correct.
   2410       self._test_recall_at_k(
   2411           predictions, labels, k=1, expected=NAN, class_id=3, weights=(0.0,))
   2412       self._test_recall_at_top_k(
   2413           predictions_idx, labels, k=1, expected=NAN, class_id=3,
   2414           weights=(0.0,))
   2415       self._test_recall_at_k(
   2416           predictions, labels, k=1, expected=1.0 / 1, class_id=3,
   2417           weights=(1.0,))
   2418       self._test_recall_at_top_k(
   2419           predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
   2420           weights=(1.0,))
   2421       self._test_recall_at_k(
   2422           predictions, labels, k=1, expected=1.0 / 1, class_id=3,
   2423           weights=(2.0,))
   2424       self._test_recall_at_top_k(
   2425           predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
   2426           weights=(2.0,))
   2427       self._test_recall_at_k(
   2428           predictions, labels, k=1, expected=NAN, class_id=3,
   2429           weights=(0.0, 1.0))
   2430       self._test_recall_at_top_k(
   2431           predictions_idx, labels, k=1, expected=NAN, class_id=3,
   2432           weights=(0.0, 1.0))
   2433       self._test_recall_at_k(
   2434           predictions, labels, k=1, expected=1.0 / 1, class_id=3,
   2435           weights=(1.0, 0.0))
   2436       self._test_recall_at_top_k(
   2437           predictions_idx, labels, k=1, expected=1.0 / 1, class_id=3,
   2438           weights=(1.0, 0.0))
   2439       self._test_recall_at_k(
   2440           predictions, labels, k=1, expected=2.0 / 2, class_id=3,
   2441           weights=(2.0, 3.0))
   2442       self._test_recall_at_top_k(
   2443           predictions_idx, labels, k=1, expected=2.0 / 2, class_id=3,
   2444           weights=(2.0, 3.0))
   2445 
   2446   def test_one_label_at_k1_weighted(self):
   2447     predictions = self._predictions
   2448     predictions_idx = self._predictions_idx
   2449     for labels in self._labels:
   2450       # All classes: 2 labels, 2 predictions, 1 correct.
   2451       self._test_recall_at_k(
   2452           predictions, labels, k=1, expected=NAN, weights=(0.0,))
   2453       self._test_recall_at_top_k(
   2454           predictions_idx, labels, k=1, expected=NAN, weights=(0.0,))
   2455       self._test_recall_at_k(
   2456           predictions, labels, k=1, expected=1.0 / 2, weights=(1.0,))
   2457       self._test_recall_at_top_k(
   2458           predictions_idx, labels, k=1, expected=1.0 / 2, weights=(1.0,))
   2459       self._test_recall_at_k(
   2460           predictions, labels, k=1, expected=1.0 / 2, weights=(2.0,))
   2461       self._test_recall_at_top_k(
   2462           predictions_idx, labels, k=1, expected=1.0 / 2, weights=(2.0,))
   2463       self._test_recall_at_k(
   2464           predictions, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
   2465       self._test_recall_at_top_k(
   2466           predictions_idx, labels, k=1, expected=1.0 / 1, weights=(1.0, 0.0))
   2467       self._test_recall_at_k(
   2468           predictions, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
   2469       self._test_recall_at_top_k(
   2470           predictions_idx, labels, k=1, expected=0.0 / 1, weights=(0.0, 1.0))
   2471       self._test_recall_at_k(
   2472           predictions, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
   2473       self._test_recall_at_top_k(
   2474           predictions_idx, labels, k=1, expected=2.0 / 5, weights=(2.0, 3.0))
   2475 
   2476 
   2477 class MultiLabel2dRecallAtKTest(test.TestCase):
   2478 
   2479   def setUp(self):
   2480     self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
   2481                          (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6))
   2482     self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6))
   2483     indicator_labels = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0),
   2484                         (0, 1, 1, 0, 0, 1, 0, 0, 0, 0))
   2485     class_labels = ((2, 7, 8), (1, 2, 5))
   2486     # Sparse vs dense labels should be handled the same.
   2487     self._labels = (_binary_2d_label_to_2d_sparse_value(indicator_labels),
   2488                     np.array(
   2489                         class_labels, dtype=np.int64))
   2490     self._test_recall_at_k = functools.partial(
   2491         _test_recall_at_k, test_case=self)
   2492     self._test_recall_at_top_k = functools.partial(
   2493         _test_recall_at_top_k, test_case=self)
   2494 
   2495   def test_at_k5_nan(self):
   2496     for labels in self._labels:
   2497       # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
   2498       for class_id in (0, 3, 4, 6, 9, 10):
   2499         self._test_recall_at_k(
   2500             self._predictions, labels, k=5, expected=NAN, class_id=class_id)
   2501         self._test_recall_at_top_k(
   2502             self._predictions_idx, labels, k=5, expected=NAN, class_id=class_id)
   2503 
   2504   def test_at_k5_no_predictions(self):
   2505     for labels in self._labels:
   2506       # Class 8: 1 label, no predictions.
   2507       self._test_recall_at_k(
   2508           self._predictions, labels, k=5, expected=0.0 / 1, class_id=8)
   2509       self._test_recall_at_top_k(
   2510           self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=8)
   2511 
   2512   def test_at_k5(self):
   2513     for labels in self._labels:
   2514       # Class 2: 2 labels, both correct.
   2515       self._test_recall_at_k(
   2516           self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
   2517       self._test_recall_at_top_k(
   2518           self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
   2519 
   2520       # Class 5: 1 label, incorrect.
   2521       self._test_recall_at_k(
   2522           self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
   2523       self._test_recall_at_top_k(
   2524           self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
   2525 
   2526       # Class 7: 1 label, incorrect.
   2527       self._test_recall_at_k(
   2528           self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
   2529       self._test_recall_at_top_k(
   2530           self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
   2531 
   2532       # All classes: 6 labels, 3 correct.
   2533       self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 6)
   2534       self._test_recall_at_top_k(
   2535           self._predictions_idx, labels, k=5, expected=3.0 / 6)
   2536 
   2537   def test_at_k5_some_out_of_range(self):
   2538     """Tests that labels outside the [0, n_classes) count in denominator."""
   2539     labels = sparse_tensor.SparseTensorValue(
   2540         indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
   2541                  [1, 3]],
   2542         # values -1 and 10 are outside the [0, n_classes) range.
   2543         values=np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64),
   2544         dense_shape=[2, 4])
   2545 
   2546     # Class 2: 2 labels, both correct.
   2547     self._test_recall_at_k(
   2548         self._predictions, labels, k=5, expected=2.0 / 2, class_id=2)
   2549     self._test_recall_at_top_k(
   2550         self._predictions_idx, labels, k=5, expected=2.0 / 2, class_id=2)
   2551 
   2552     # Class 5: 1 label, incorrect.
   2553     self._test_recall_at_k(
   2554         self._predictions, labels, k=5, expected=1.0 / 1, class_id=5)
   2555     self._test_recall_at_top_k(
   2556         self._predictions_idx, labels, k=5, expected=1.0 / 1, class_id=5)
   2557 
   2558     # Class 7: 1 label, incorrect.
   2559     self._test_recall_at_k(
   2560         self._predictions, labels, k=5, expected=0.0 / 1, class_id=7)
   2561     self._test_recall_at_top_k(
   2562         self._predictions_idx, labels, k=5, expected=0.0 / 1, class_id=7)
   2563 
   2564     # All classes: 8 labels, 3 correct.
   2565     self._test_recall_at_k(self._predictions, labels, k=5, expected=3.0 / 8)
   2566     self._test_recall_at_top_k(
   2567         self._predictions_idx, labels, k=5, expected=3.0 / 8)
   2568 
   2569 
   2570 class MultiLabel3dRecallAtKTest(test.TestCase):
   2571 
   2572   def setUp(self):
   2573     self._predictions = (((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
   2574                           (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6)),
   2575                          ((0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6),
   2576                           (0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9)))
   2577     self._predictions_idx = (((9, 4, 6, 2, 0), (5, 7, 2, 9, 6)),
   2578                              ((5, 7, 2, 9, 6), (9, 4, 6, 2, 0)))
   2579     # Note: We don't test dense labels here, since examples have different
   2580     # numbers of labels.
   2581     self._labels = _binary_3d_label_to_sparse_value(((
   2582         (0, 0, 1, 0, 0, 0, 0, 1, 1, 0), (0, 1, 1, 0, 0, 1, 0, 0, 0, 0)), (
   2583             (0, 1, 1, 0, 0, 1, 0, 1, 0, 0), (0, 0, 1, 0, 0, 0, 0, 0, 1, 0))))
   2584     self._test_recall_at_k = functools.partial(
   2585         _test_recall_at_k, test_case=self)
   2586     self._test_recall_at_top_k = functools.partial(
   2587         _test_recall_at_top_k, test_case=self)
   2588 
   2589   def test_3d_nan(self):
   2590     # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
   2591     for class_id in (0, 3, 4, 6, 9, 10):
   2592       self._test_recall_at_k(
   2593           self._predictions, self._labels, k=5, expected=NAN, class_id=class_id)
   2594       self._test_recall_at_top_k(
   2595           self._predictions_idx, self._labels, k=5, expected=NAN,
   2596           class_id=class_id)
   2597 
   2598   def test_3d_no_predictions(self):
   2599     # Classes 1,8 have 0 predictions, >=1 label.
   2600     for class_id in (1, 8):
   2601       self._test_recall_at_k(
   2602           self._predictions, self._labels, k=5, expected=0.0, class_id=class_id)
   2603       self._test_recall_at_top_k(
   2604           self._predictions_idx, self._labels, k=5, expected=0.0,
   2605           class_id=class_id)
   2606 
   2607   def test_3d(self):
   2608     # Class 2: 4 labels, all correct.
   2609     self._test_recall_at_k(
   2610         self._predictions, self._labels, k=5, expected=4.0 / 4, class_id=2)
   2611     self._test_recall_at_top_k(
   2612         self._predictions_idx, self._labels, k=5, expected=4.0 / 4,
   2613         class_id=2)
   2614 
   2615     # Class 5: 2 labels, both correct.
   2616     self._test_recall_at_k(
   2617         self._predictions, self._labels, k=5, expected=2.0 / 2, class_id=5)
   2618     self._test_recall_at_top_k(
   2619         self._predictions_idx, self._labels, k=5, expected=2.0 / 2,
   2620         class_id=5)
   2621 
   2622     # Class 7: 2 labels, 1 incorrect.
   2623     self._test_recall_at_k(
   2624         self._predictions, self._labels, k=5, expected=1.0 / 2, class_id=7)
   2625     self._test_recall_at_top_k(
   2626         self._predictions_idx, self._labels, k=5, expected=1.0 / 2,
   2627         class_id=7)
   2628 
   2629     # All classes: 12 labels, 7 correct.
   2630     self._test_recall_at_k(
   2631         self._predictions, self._labels, k=5, expected=7.0 / 12)
   2632     self._test_recall_at_top_k(
   2633         self._predictions_idx, self._labels, k=5, expected=7.0 / 12)
   2634 
   2635   def test_3d_ignore_all(self):
   2636     for class_id in xrange(10):
   2637       self._test_recall_at_k(
   2638           self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
   2639           weights=[[0], [0]])
   2640       self._test_recall_at_top_k(
   2641           self._predictions_idx, self._labels, k=5, expected=NAN,
   2642           class_id=class_id, weights=[[0], [0]])
   2643       self._test_recall_at_k(
   2644           self._predictions, self._labels, k=5, expected=NAN, class_id=class_id,
   2645           weights=[[0, 0], [0, 0]])
   2646       self._test_recall_at_top_k(
   2647           self._predictions_idx, self._labels, k=5, expected=NAN,
   2648           class_id=class_id, weights=[[0, 0], [0, 0]])
   2649     self._test_recall_at_k(
   2650         self._predictions, self._labels, k=5, expected=NAN, weights=[[0], [0]])
   2651     self._test_recall_at_top_k(
   2652         self._predictions_idx, self._labels, k=5, expected=NAN,
   2653         weights=[[0], [0]])
   2654     self._test_recall_at_k(
   2655         self._predictions, self._labels, k=5, expected=NAN,
   2656         weights=[[0, 0], [0, 0]])
   2657     self._test_recall_at_top_k(
   2658         self._predictions_idx, self._labels, k=5, expected=NAN,
   2659         weights=[[0, 0], [0, 0]])
   2660 
   2661   def test_3d_ignore_some(self):
   2662     # Class 2: 2 labels, both correct.
   2663     self._test_recall_at_k(
   2664         self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
   2665         weights=[[1], [0]])
   2666     self._test_recall_at_top_k(
   2667         self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
   2668         class_id=2, weights=[[1], [0]])
   2669 
   2670     # Class 2: 2 labels, both correct.
   2671     self._test_recall_at_k(
   2672         self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2,
   2673         weights=[[0], [1]])
   2674     self._test_recall_at_top_k(
   2675         self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0,
   2676         class_id=2, weights=[[0], [1]])
   2677 
   2678     # Class 7: 1 label, correct.
   2679     self._test_recall_at_k(
   2680         self._predictions, self._labels, k=5, expected=1.0 / 1.0, class_id=7,
   2681         weights=[[0], [1]])
   2682     self._test_recall_at_top_k(
   2683         self._predictions_idx, self._labels, k=5, expected=1.0 / 1.0,
   2684         class_id=7, weights=[[0], [1]])
   2685 
   2686     # Class 7: 1 label, incorrect.
   2687     self._test_recall_at_k(
   2688         self._predictions, self._labels, k=5, expected=0.0 / 1.0, class_id=7,
   2689         weights=[[1], [0]])
   2690     self._test_recall_at_top_k(
   2691         self._predictions_idx, self._labels, k=5, expected=0.0 / 1.0,
   2692         class_id=7, weights=[[1], [0]])
   2693 
   2694     # Class 7: 2 labels, 1 correct.
   2695     self._test_recall_at_k(
   2696         self._predictions, self._labels, k=5, expected=1.0 / 2.0, class_id=7,
   2697         weights=[[1, 0], [1, 0]])
   2698     self._test_recall_at_top_k(
   2699         self._predictions_idx, self._labels, k=5, expected=1.0 / 2.0,
   2700         class_id=7, weights=[[1, 0], [1, 0]])
   2701 
   2702     # Class 7: No labels.
   2703     self._test_recall_at_k(
   2704         self._predictions, self._labels, k=5, expected=NAN, class_id=7,
   2705         weights=[[0, 1], [0, 1]])
   2706     self._test_recall_at_top_k(
   2707         self._predictions_idx, self._labels, k=5, expected=NAN, class_id=7,
   2708         weights=[[0, 1], [0, 1]])
   2709 
   2710 
   2711 class MeanAbsoluteErrorTest(test.TestCase):
   2712 
   2713   def setUp(self):
   2714     ops.reset_default_graph()
   2715 
   2716   def testVars(self):
   2717     metrics.mean_absolute_error(
   2718         predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
   2719     _assert_metric_variables(
   2720         self, ('mean_absolute_error/count:0', 'mean_absolute_error/total:0'))
   2721 
   2722   def testMetricsCollection(self):
   2723     my_collection_name = '__metrics__'
   2724     mean, _ = metrics.mean_absolute_error(
   2725         predictions=array_ops.ones((10, 1)),
   2726         labels=array_ops.ones((10, 1)),
   2727         metrics_collections=[my_collection_name])
   2728     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   2729 
   2730   def testUpdatesCollection(self):
   2731     my_collection_name = '__updates__'
   2732     _, update_op = metrics.mean_absolute_error(
   2733         predictions=array_ops.ones((10, 1)),
   2734         labels=array_ops.ones((10, 1)),
   2735         updates_collections=[my_collection_name])
   2736     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   2737 
   2738   def testValueTensorIsIdempotent(self):
   2739     predictions = random_ops.random_normal((10, 3), seed=1)
   2740     labels = random_ops.random_normal((10, 3), seed=2)
   2741     error, update_op = metrics.mean_absolute_error(labels, predictions)
   2742 
   2743     with self.test_session() as sess:
   2744       sess.run(variables.local_variables_initializer())
   2745 
   2746       # Run several updates.
   2747       for _ in range(10):
   2748         sess.run(update_op)
   2749 
   2750       # Then verify idempotency.
   2751       initial_error = error.eval()
   2752       for _ in range(10):
   2753         self.assertEqual(initial_error, error.eval())
   2754 
   2755   def testSingleUpdateWithErrorAndWeights(self):
   2756     predictions = constant_op.constant(
   2757         [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
   2758     labels = constant_op.constant(
   2759         [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
   2760     weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
   2761 
   2762     error, update_op = metrics.mean_absolute_error(labels, predictions, weights)
   2763 
   2764     with self.test_session() as sess:
   2765       sess.run(variables.local_variables_initializer())
   2766       self.assertEqual(3, sess.run(update_op))
   2767       self.assertEqual(3, error.eval())
   2768 
   2769 
   2770 class MeanRelativeErrorTest(test.TestCase):
   2771 
   2772   def setUp(self):
   2773     ops.reset_default_graph()
   2774 
   2775   def testVars(self):
   2776     metrics.mean_relative_error(
   2777         predictions=array_ops.ones((10, 1)),
   2778         labels=array_ops.ones((10, 1)),
   2779         normalizer=array_ops.ones((10, 1)))
   2780     _assert_metric_variables(
   2781         self, ('mean_relative_error/count:0', 'mean_relative_error/total:0'))
   2782 
   2783   def testMetricsCollection(self):
   2784     my_collection_name = '__metrics__'
   2785     mean, _ = metrics.mean_relative_error(
   2786         predictions=array_ops.ones((10, 1)),
   2787         labels=array_ops.ones((10, 1)),
   2788         normalizer=array_ops.ones((10, 1)),
   2789         metrics_collections=[my_collection_name])
   2790     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   2791 
   2792   def testUpdatesCollection(self):
   2793     my_collection_name = '__updates__'
   2794     _, update_op = metrics.mean_relative_error(
   2795         predictions=array_ops.ones((10, 1)),
   2796         labels=array_ops.ones((10, 1)),
   2797         normalizer=array_ops.ones((10, 1)),
   2798         updates_collections=[my_collection_name])
   2799     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   2800 
   2801   def testValueTensorIsIdempotent(self):
   2802     predictions = random_ops.random_normal((10, 3), seed=1)
   2803     labels = random_ops.random_normal((10, 3), seed=2)
   2804     normalizer = random_ops.random_normal((10, 3), seed=3)
   2805     error, update_op = metrics.mean_relative_error(labels, predictions,
   2806                                                    normalizer)
   2807 
   2808     with self.test_session() as sess:
   2809       sess.run(variables.local_variables_initializer())
   2810 
   2811       # Run several updates.
   2812       for _ in range(10):
   2813         sess.run(update_op)
   2814 
   2815       # Then verify idempotency.
   2816       initial_error = error.eval()
   2817       for _ in range(10):
   2818         self.assertEqual(initial_error, error.eval())
   2819 
   2820   def testSingleUpdateNormalizedByLabels(self):
   2821     np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
   2822     np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32)
   2823     expected_error = np.mean(
   2824         np.divide(np.absolute(np_predictions - np_labels), np_labels))
   2825 
   2826     predictions = constant_op.constant(
   2827         np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
   2828     labels = constant_op.constant(np_labels, shape=(1, 4))
   2829 
   2830     error, update_op = metrics.mean_relative_error(
   2831         labels, predictions, normalizer=labels)
   2832 
   2833     with self.test_session() as sess:
   2834       sess.run(variables.local_variables_initializer())
   2835       self.assertEqual(expected_error, sess.run(update_op))
   2836       self.assertEqual(expected_error, error.eval())
   2837 
   2838   def testSingleUpdateNormalizedByZeros(self):
   2839     np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32)
   2840 
   2841     predictions = constant_op.constant(
   2842         np_predictions, shape=(1, 4), dtype=dtypes_lib.float32)
   2843     labels = constant_op.constant(
   2844         [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
   2845 
   2846     error, update_op = metrics.mean_relative_error(
   2847         labels, predictions, normalizer=array_ops.zeros_like(labels))
   2848 
   2849     with self.test_session() as sess:
   2850       sess.run(variables.local_variables_initializer())
   2851       self.assertEqual(0.0, sess.run(update_op))
   2852       self.assertEqual(0.0, error.eval())
   2853 
   2854 
   2855 class MeanSquaredErrorTest(test.TestCase):
   2856 
   2857   def setUp(self):
   2858     ops.reset_default_graph()
   2859 
   2860   def testVars(self):
   2861     metrics.mean_squared_error(
   2862         predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
   2863     _assert_metric_variables(
   2864         self, ('mean_squared_error/count:0', 'mean_squared_error/total:0'))
   2865 
   2866   def testMetricsCollection(self):
   2867     my_collection_name = '__metrics__'
   2868     mean, _ = metrics.mean_squared_error(
   2869         predictions=array_ops.ones((10, 1)),
   2870         labels=array_ops.ones((10, 1)),
   2871         metrics_collections=[my_collection_name])
   2872     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   2873 
   2874   def testUpdatesCollection(self):
   2875     my_collection_name = '__updates__'
   2876     _, update_op = metrics.mean_squared_error(
   2877         predictions=array_ops.ones((10, 1)),
   2878         labels=array_ops.ones((10, 1)),
   2879         updates_collections=[my_collection_name])
   2880     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   2881 
   2882   def testValueTensorIsIdempotent(self):
   2883     predictions = random_ops.random_normal((10, 3), seed=1)
   2884     labels = random_ops.random_normal((10, 3), seed=2)
   2885     error, update_op = metrics.mean_squared_error(labels, predictions)
   2886 
   2887     with self.test_session() as sess:
   2888       sess.run(variables.local_variables_initializer())
   2889 
   2890       # Run several updates.
   2891       for _ in range(10):
   2892         sess.run(update_op)
   2893 
   2894       # Then verify idempotency.
   2895       initial_error = error.eval()
   2896       for _ in range(10):
   2897         self.assertEqual(initial_error, error.eval())
   2898 
   2899   def testSingleUpdateZeroError(self):
   2900     predictions = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)
   2901     labels = array_ops.zeros((1, 3), dtype=dtypes_lib.float32)
   2902 
   2903     error, update_op = metrics.mean_squared_error(labels, predictions)
   2904 
   2905     with self.test_session() as sess:
   2906       sess.run(variables.local_variables_initializer())
   2907       self.assertEqual(0, sess.run(update_op))
   2908       self.assertEqual(0, error.eval())
   2909 
   2910   def testSingleUpdateWithError(self):
   2911     predictions = constant_op.constant(
   2912         [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
   2913     labels = constant_op.constant(
   2914         [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
   2915 
   2916     error, update_op = metrics.mean_squared_error(labels, predictions)
   2917 
   2918     with self.test_session() as sess:
   2919       sess.run(variables.local_variables_initializer())
   2920       self.assertEqual(6, sess.run(update_op))
   2921       self.assertEqual(6, error.eval())
   2922 
   2923   def testSingleUpdateWithErrorAndWeights(self):
   2924     predictions = constant_op.constant(
   2925         [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
   2926     labels = constant_op.constant(
   2927         [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
   2928     weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
   2929 
   2930     error, update_op = metrics.mean_squared_error(labels, predictions, weights)
   2931 
   2932     with self.test_session() as sess:
   2933       sess.run(variables.local_variables_initializer())
   2934       self.assertEqual(13, sess.run(update_op))
   2935       self.assertEqual(13, error.eval())
   2936 
   2937   def testMultipleBatchesOfSizeOne(self):
   2938     with self.test_session() as sess:
   2939       # Create the queue that populates the predictions.
   2940       preds_queue = data_flow_ops.FIFOQueue(
   2941           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2942       _enqueue_vector(sess, preds_queue, [10, 8, 6])
   2943       _enqueue_vector(sess, preds_queue, [-4, 3, -1])
   2944       predictions = preds_queue.dequeue()
   2945 
   2946       # Create the queue that populates the labels.
   2947       labels_queue = data_flow_ops.FIFOQueue(
   2948           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2949       _enqueue_vector(sess, labels_queue, [1, 3, 2])
   2950       _enqueue_vector(sess, labels_queue, [2, 4, 6])
   2951       labels = labels_queue.dequeue()
   2952 
   2953       error, update_op = metrics.mean_squared_error(labels, predictions)
   2954 
   2955       sess.run(variables.local_variables_initializer())
   2956       sess.run(update_op)
   2957       self.assertAlmostEqual(208.0 / 6, sess.run(update_op), 5)
   2958 
   2959       self.assertAlmostEqual(208.0 / 6, error.eval(), 5)
   2960 
   2961   def testMetricsComputedConcurrently(self):
   2962     with self.test_session() as sess:
   2963       # Create the queue that populates one set of predictions.
   2964       preds_queue0 = data_flow_ops.FIFOQueue(
   2965           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2966       _enqueue_vector(sess, preds_queue0, [10, 8, 6])
   2967       _enqueue_vector(sess, preds_queue0, [-4, 3, -1])
   2968       predictions0 = preds_queue0.dequeue()
   2969 
   2970       # Create the queue that populates one set of predictions.
   2971       preds_queue1 = data_flow_ops.FIFOQueue(
   2972           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2973       _enqueue_vector(sess, preds_queue1, [0, 1, 1])
   2974       _enqueue_vector(sess, preds_queue1, [1, 1, 0])
   2975       predictions1 = preds_queue1.dequeue()
   2976 
   2977       # Create the queue that populates one set of labels.
   2978       labels_queue0 = data_flow_ops.FIFOQueue(
   2979           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2980       _enqueue_vector(sess, labels_queue0, [1, 3, 2])
   2981       _enqueue_vector(sess, labels_queue0, [2, 4, 6])
   2982       labels0 = labels_queue0.dequeue()
   2983 
   2984       # Create the queue that populates another set of labels.
   2985       labels_queue1 = data_flow_ops.FIFOQueue(
   2986           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   2987       _enqueue_vector(sess, labels_queue1, [-5, -3, -1])
   2988       _enqueue_vector(sess, labels_queue1, [5, 4, 3])
   2989       labels1 = labels_queue1.dequeue()
   2990 
   2991       mse0, update_op0 = metrics.mean_squared_error(
   2992           labels0, predictions0, name='msd0')
   2993       mse1, update_op1 = metrics.mean_squared_error(
   2994           labels1, predictions1, name='msd1')
   2995 
   2996       sess.run(variables.local_variables_initializer())
   2997       sess.run([update_op0, update_op1])
   2998       sess.run([update_op0, update_op1])
   2999 
   3000       mse0, mse1 = sess.run([mse0, mse1])
   3001       self.assertAlmostEqual(208.0 / 6, mse0, 5)
   3002       self.assertAlmostEqual(79.0 / 6, mse1, 5)
   3003 
   3004   def testMultipleMetricsOnMultipleBatchesOfSizeOne(self):
   3005     with self.test_session() as sess:
   3006       # Create the queue that populates the predictions.
   3007       preds_queue = data_flow_ops.FIFOQueue(
   3008           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   3009       _enqueue_vector(sess, preds_queue, [10, 8, 6])
   3010       _enqueue_vector(sess, preds_queue, [-4, 3, -1])
   3011       predictions = preds_queue.dequeue()
   3012 
   3013       # Create the queue that populates the labels.
   3014       labels_queue = data_flow_ops.FIFOQueue(
   3015           2, dtypes=dtypes_lib.float32, shapes=(1, 3))
   3016       _enqueue_vector(sess, labels_queue, [1, 3, 2])
   3017       _enqueue_vector(sess, labels_queue, [2, 4, 6])
   3018       labels = labels_queue.dequeue()
   3019 
   3020       mae, ma_update_op = metrics.mean_absolute_error(labels, predictions)
   3021       mse, ms_update_op = metrics.mean_squared_error(labels, predictions)
   3022 
   3023       sess.run(variables.local_variables_initializer())
   3024       sess.run([ma_update_op, ms_update_op])
   3025       sess.run([ma_update_op, ms_update_op])
   3026 
   3027       self.assertAlmostEqual(32.0 / 6, mae.eval(), 5)
   3028       self.assertAlmostEqual(208.0 / 6, mse.eval(), 5)
   3029 
   3030 
   3031 class RootMeanSquaredErrorTest(test.TestCase):
   3032 
   3033   def setUp(self):
   3034     ops.reset_default_graph()
   3035 
   3036   def testVars(self):
   3037     metrics.root_mean_squared_error(
   3038         predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
   3039     _assert_metric_variables(
   3040         self,
   3041         ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0'))
   3042 
   3043   def testMetricsCollection(self):
   3044     my_collection_name = '__metrics__'
   3045     mean, _ = metrics.root_mean_squared_error(
   3046         predictions=array_ops.ones((10, 1)),
   3047         labels=array_ops.ones((10, 1)),
   3048         metrics_collections=[my_collection_name])
   3049     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   3050 
   3051   def testUpdatesCollection(self):
   3052     my_collection_name = '__updates__'
   3053     _, update_op = metrics.root_mean_squared_error(
   3054         predictions=array_ops.ones((10, 1)),
   3055         labels=array_ops.ones((10, 1)),
   3056         updates_collections=[my_collection_name])
   3057     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   3058 
   3059   def testValueTensorIsIdempotent(self):
   3060     predictions = random_ops.random_normal((10, 3), seed=1)
   3061     labels = random_ops.random_normal((10, 3), seed=2)
   3062     error, update_op = metrics.root_mean_squared_error(labels, predictions)
   3063 
   3064     with self.test_session() as sess:
   3065       sess.run(variables.local_variables_initializer())
   3066 
   3067       # Run several updates.
   3068       for _ in range(10):
   3069         sess.run(update_op)
   3070 
   3071       # Then verify idempotency.
   3072       initial_error = error.eval()
   3073       for _ in range(10):
   3074         self.assertEqual(initial_error, error.eval())
   3075 
   3076   def testSingleUpdateZeroError(self):
   3077     with self.test_session() as sess:
   3078       predictions = constant_op.constant(
   3079           0.0, shape=(1, 3), dtype=dtypes_lib.float32)
   3080       labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32)
   3081 
   3082       rmse, update_op = metrics.root_mean_squared_error(labels, predictions)
   3083 
   3084       sess.run(variables.local_variables_initializer())
   3085       self.assertEqual(0, sess.run(update_op))
   3086 
   3087       self.assertEqual(0, rmse.eval())
   3088 
   3089   def testSingleUpdateWithError(self):
   3090     with self.test_session() as sess:
   3091       predictions = constant_op.constant(
   3092           [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32)
   3093       labels = constant_op.constant(
   3094           [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32)
   3095 
   3096       rmse, update_op = metrics.root_mean_squared_error(labels, predictions)
   3097 
   3098       sess.run(variables.local_variables_initializer())
   3099       self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5)
   3100       self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5)
   3101 
   3102   def testSingleUpdateWithErrorAndWeights(self):
   3103     with self.test_session() as sess:
   3104       predictions = constant_op.constant(
   3105           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
   3106       labels = constant_op.constant(
   3107           [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32)
   3108       weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4))
   3109 
   3110       rmse, update_op = metrics.root_mean_squared_error(labels, predictions,
   3111                                                         weights)
   3112 
   3113       sess.run(variables.local_variables_initializer())
   3114       self.assertAlmostEqual(math.sqrt(13), sess.run(update_op))
   3115 
   3116       self.assertAlmostEqual(math.sqrt(13), rmse.eval(), 5)
   3117 
   3118 
   3119 def _reweight(predictions, labels, weights):
   3120   return (np.concatenate([[p] * int(w) for p, w in zip(predictions, weights)]),
   3121           np.concatenate([[l] * int(w) for l, w in zip(labels, weights)]))
   3122 
   3123 
   3124 class MeanCosineDistanceTest(test.TestCase):
   3125 
   3126   def setUp(self):
   3127     ops.reset_default_graph()
   3128 
   3129   def testVars(self):
   3130     metrics.mean_cosine_distance(
   3131         predictions=array_ops.ones((10, 3)),
   3132         labels=array_ops.ones((10, 3)),
   3133         dim=1)
   3134     _assert_metric_variables(self, (
   3135         'mean_cosine_distance/count:0',
   3136         'mean_cosine_distance/total:0',
   3137     ))
   3138 
   3139   def testMetricsCollection(self):
   3140     my_collection_name = '__metrics__'
   3141     mean, _ = metrics.mean_cosine_distance(
   3142         predictions=array_ops.ones((10, 3)),
   3143         labels=array_ops.ones((10, 3)),
   3144         dim=1,
   3145         metrics_collections=[my_collection_name])
   3146     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   3147 
   3148   def testUpdatesCollection(self):
   3149     my_collection_name = '__updates__'
   3150     _, update_op = metrics.mean_cosine_distance(
   3151         predictions=array_ops.ones((10, 3)),
   3152         labels=array_ops.ones((10, 3)),
   3153         dim=1,
   3154         updates_collections=[my_collection_name])
   3155     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   3156 
   3157   def testValueTensorIsIdempotent(self):
   3158     predictions = random_ops.random_normal((10, 3), seed=1)
   3159     labels = random_ops.random_normal((10, 3), seed=2)
   3160     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1)
   3161 
   3162     with self.test_session() as sess:
   3163       sess.run(variables.local_variables_initializer())
   3164 
   3165       # Run several updates.
   3166       for _ in range(10):
   3167         sess.run(update_op)
   3168 
   3169       # Then verify idempotency.
   3170       initial_error = error.eval()
   3171       for _ in range(10):
   3172         self.assertEqual(initial_error, error.eval())
   3173 
   3174   def testSingleUpdateZeroError(self):
   3175     np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
   3176 
   3177     predictions = constant_op.constant(
   3178         np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)
   3179     labels = constant_op.constant(
   3180         np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32)
   3181 
   3182     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
   3183 
   3184     with self.test_session() as sess:
   3185       sess.run(variables.local_variables_initializer())
   3186       self.assertEqual(0, sess.run(update_op))
   3187       self.assertEqual(0, error.eval())
   3188 
   3189   def testSingleUpdateWithError1(self):
   3190     np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
   3191     np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
   3192 
   3193     predictions = constant_op.constant(
   3194         np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3195     labels = constant_op.constant(
   3196         np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3197 
   3198     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
   3199 
   3200     with self.test_session() as sess:
   3201       sess.run(variables.local_variables_initializer())
   3202       self.assertAlmostEqual(1, sess.run(update_op), 5)
   3203       self.assertAlmostEqual(1, error.eval(), 5)
   3204 
   3205   def testSingleUpdateWithError2(self):
   3206     np_predictions = np.matrix(
   3207         ('0.819031913261206 0.567041924552012 0.087465312324590;'
   3208          '-0.665139432070255 -0.739487441769973 -0.103671883216994;'
   3209          '0.707106781186548 -0.707106781186548 0'))
   3210     np_labels = np.matrix(
   3211         ('0.819031913261206 0.567041924552012 0.087465312324590;'
   3212          '0.665139432070255 0.739487441769973 0.103671883216994;'
   3213          '0.707106781186548 0.707106781186548 0'))
   3214 
   3215     predictions = constant_op.constant(
   3216         np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3217     labels = constant_op.constant(
   3218         np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3219     error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2)
   3220 
   3221     with self.test_session() as sess:
   3222       sess.run(variables.local_variables_initializer())
   3223       self.assertAlmostEqual(1.0, sess.run(update_op), 5)
   3224       self.assertAlmostEqual(1.0, error.eval(), 5)
   3225 
   3226   def testSingleUpdateWithErrorAndWeights1(self):
   3227     np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
   3228     np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
   3229 
   3230     predictions = constant_op.constant(
   3231         np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3232     labels = constant_op.constant(
   3233         np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3234     weights = constant_op.constant(
   3235         [1, 0, 0], shape=(3, 1, 1), dtype=dtypes_lib.float32)
   3236 
   3237     error, update_op = metrics.mean_cosine_distance(
   3238         labels, predictions, dim=2, weights=weights)
   3239 
   3240     with self.test_session() as sess:
   3241       sess.run(variables.local_variables_initializer())
   3242       self.assertEqual(0, sess.run(update_op))
   3243       self.assertEqual(0, error.eval())
   3244 
   3245   def testSingleUpdateWithErrorAndWeights2(self):
   3246     np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0'))
   3247     np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0'))
   3248 
   3249     predictions = constant_op.constant(
   3250         np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3251     labels = constant_op.constant(
   3252         np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32)
   3253     weights = constant_op.constant(
   3254         [0, 1, 1], shape=(3, 1, 1), dtype=dtypes_lib.float32)
   3255 
   3256     error, update_op = metrics.mean_cosine_distance(
   3257         labels, predictions, dim=2, weights=weights)
   3258 
   3259     with self.test_session() as sess:
   3260       sess.run(variables.local_variables_initializer())
   3261       self.assertEqual(1.5, update_op.eval())
   3262       self.assertEqual(1.5, error.eval())
   3263 
   3264 
   3265 class PcntBelowThreshTest(test.TestCase):
   3266 
   3267   def setUp(self):
   3268     ops.reset_default_graph()
   3269 
   3270   def testVars(self):
   3271     metrics.percentage_below(values=array_ops.ones((10,)), threshold=2)
   3272     _assert_metric_variables(self, (
   3273         'percentage_below_threshold/count:0',
   3274         'percentage_below_threshold/total:0',
   3275     ))
   3276 
   3277   def testMetricsCollection(self):
   3278     my_collection_name = '__metrics__'
   3279     mean, _ = metrics.percentage_below(
   3280         values=array_ops.ones((10,)),
   3281         threshold=2,
   3282         metrics_collections=[my_collection_name])
   3283     self.assertListEqual(ops.get_collection(my_collection_name), [mean])
   3284 
   3285   def testUpdatesCollection(self):
   3286     my_collection_name = '__updates__'
   3287     _, update_op = metrics.percentage_below(
   3288         values=array_ops.ones((10,)),
   3289         threshold=2,
   3290         updates_collections=[my_collection_name])
   3291     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   3292 
   3293   def testOneUpdate(self):
   3294     with self.test_session() as sess:
   3295       values = constant_op.constant(
   3296           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
   3297 
   3298       pcnt0, update_op0 = metrics.percentage_below(values, 100, name='high')
   3299       pcnt1, update_op1 = metrics.percentage_below(values, 7, name='medium')
   3300       pcnt2, update_op2 = metrics.percentage_below(values, 1, name='low')
   3301 
   3302       sess.run(variables.local_variables_initializer())
   3303       sess.run([update_op0, update_op1, update_op2])
   3304 
   3305       pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2])
   3306       self.assertAlmostEqual(1.0, pcnt0, 5)
   3307       self.assertAlmostEqual(0.75, pcnt1, 5)
   3308       self.assertAlmostEqual(0.0, pcnt2, 5)
   3309 
   3310   def testSomePresentOneUpdate(self):
   3311     with self.test_session() as sess:
   3312       values = constant_op.constant(
   3313           [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32)
   3314       weights = constant_op.constant(
   3315           [1, 0, 0, 1], shape=(1, 4), dtype=dtypes_lib.float32)
   3316 
   3317       pcnt0, update_op0 = metrics.percentage_below(
   3318           values, 100, weights=weights, name='high')
   3319       pcnt1, update_op1 = metrics.percentage_below(
   3320           values, 7, weights=weights, name='medium')
   3321       pcnt2, update_op2 = metrics.percentage_below(
   3322           values, 1, weights=weights, name='low')
   3323 
   3324       sess.run(variables.local_variables_initializer())
   3325       self.assertListEqual([1.0, 0.5, 0.0],
   3326                            sess.run([update_op0, update_op1, update_op2]))
   3327 
   3328       pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2])
   3329       self.assertAlmostEqual(1.0, pcnt0, 5)
   3330       self.assertAlmostEqual(0.5, pcnt1, 5)
   3331       self.assertAlmostEqual(0.0, pcnt2, 5)
   3332 
   3333 
   3334 class MeanIOUTest(test.TestCase):
   3335 
   3336   def setUp(self):
   3337     np.random.seed(1)
   3338     ops.reset_default_graph()
   3339 
   3340   def testVars(self):
   3341     metrics.mean_iou(
   3342         predictions=array_ops.ones([10, 1]),
   3343         labels=array_ops.ones([10, 1]),
   3344         num_classes=2)
   3345     _assert_metric_variables(self, ('mean_iou/total_confusion_matrix:0',))
   3346 
   3347   def testMetricsCollections(self):
   3348     my_collection_name = '__metrics__'
   3349     mean_iou, _ = metrics.mean_iou(
   3350         predictions=array_ops.ones([10, 1]),
   3351         labels=array_ops.ones([10, 1]),
   3352         num_classes=2,
   3353         metrics_collections=[my_collection_name])
   3354     self.assertListEqual(ops.get_collection(my_collection_name), [mean_iou])
   3355 
   3356   def testUpdatesCollection(self):
   3357     my_collection_name = '__updates__'
   3358     _, update_op = metrics.mean_iou(
   3359         predictions=array_ops.ones([10, 1]),
   3360         labels=array_ops.ones([10, 1]),
   3361         num_classes=2,
   3362         updates_collections=[my_collection_name])
   3363     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   3364 
   3365   def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
   3366     predictions = array_ops.ones([10, 3])
   3367     labels = array_ops.ones([10, 4])
   3368     with self.assertRaises(ValueError):
   3369       metrics.mean_iou(labels, predictions, num_classes=2)
   3370 
   3371   def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
   3372     predictions = array_ops.ones([10])
   3373     labels = array_ops.ones([10])
   3374     weights = array_ops.zeros([9])
   3375     with self.assertRaises(ValueError):
   3376       metrics.mean_iou(labels, predictions, num_classes=2, weights=weights)
   3377 
   3378   def testValueTensorIsIdempotent(self):
   3379     num_classes = 3
   3380     predictions = random_ops.random_uniform(
   3381         [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
   3382     labels = random_ops.random_uniform(
   3383         [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
   3384     mean_iou, update_op = metrics.mean_iou(
   3385         labels, predictions, num_classes=num_classes)
   3386 
   3387     with self.test_session() as sess:
   3388       sess.run(variables.local_variables_initializer())
   3389 
   3390       # Run several updates.
   3391       for _ in range(10):
   3392         sess.run(update_op)
   3393 
   3394       # Then verify idempotency.
   3395       initial_mean_iou = mean_iou.eval()
   3396       for _ in range(10):
   3397         self.assertEqual(initial_mean_iou, mean_iou.eval())
   3398 
   3399   def testMultipleUpdates(self):
   3400     num_classes = 3
   3401     with self.test_session() as sess:
   3402       # Create the queue that populates the predictions.
   3403       preds_queue = data_flow_ops.FIFOQueue(
   3404           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3405       _enqueue_vector(sess, preds_queue, [0])
   3406       _enqueue_vector(sess, preds_queue, [1])
   3407       _enqueue_vector(sess, preds_queue, [2])
   3408       _enqueue_vector(sess, preds_queue, [1])
   3409       _enqueue_vector(sess, preds_queue, [0])
   3410       predictions = preds_queue.dequeue()
   3411 
   3412       # Create the queue that populates the labels.
   3413       labels_queue = data_flow_ops.FIFOQueue(
   3414           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3415       _enqueue_vector(sess, labels_queue, [0])
   3416       _enqueue_vector(sess, labels_queue, [1])
   3417       _enqueue_vector(sess, labels_queue, [1])
   3418       _enqueue_vector(sess, labels_queue, [2])
   3419       _enqueue_vector(sess, labels_queue, [1])
   3420       labels = labels_queue.dequeue()
   3421 
   3422       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3423 
   3424       sess.run(variables.local_variables_initializer())
   3425       for _ in range(5):
   3426         sess.run(update_op)
   3427       desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.])
   3428       self.assertEqual(desired_output, miou.eval())
   3429 
   3430   def testMultipleUpdatesWithWeights(self):
   3431     num_classes = 2
   3432     with self.test_session() as sess:
   3433       # Create the queue that populates the predictions.
   3434       preds_queue = data_flow_ops.FIFOQueue(
   3435           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3436       _enqueue_vector(sess, preds_queue, [0])
   3437       _enqueue_vector(sess, preds_queue, [1])
   3438       _enqueue_vector(sess, preds_queue, [0])
   3439       _enqueue_vector(sess, preds_queue, [1])
   3440       _enqueue_vector(sess, preds_queue, [0])
   3441       _enqueue_vector(sess, preds_queue, [1])
   3442       predictions = preds_queue.dequeue()
   3443 
   3444       # Create the queue that populates the labels.
   3445       labels_queue = data_flow_ops.FIFOQueue(
   3446           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3447       _enqueue_vector(sess, labels_queue, [0])
   3448       _enqueue_vector(sess, labels_queue, [1])
   3449       _enqueue_vector(sess, labels_queue, [1])
   3450       _enqueue_vector(sess, labels_queue, [0])
   3451       _enqueue_vector(sess, labels_queue, [0])
   3452       _enqueue_vector(sess, labels_queue, [1])
   3453       labels = labels_queue.dequeue()
   3454 
   3455       # Create the queue that populates the weights.
   3456       weights_queue = data_flow_ops.FIFOQueue(
   3457           6, dtypes=dtypes_lib.float32, shapes=(1, 1))
   3458       _enqueue_vector(sess, weights_queue, [1.0])
   3459       _enqueue_vector(sess, weights_queue, [1.0])
   3460       _enqueue_vector(sess, weights_queue, [1.0])
   3461       _enqueue_vector(sess, weights_queue, [0.0])
   3462       _enqueue_vector(sess, weights_queue, [1.0])
   3463       _enqueue_vector(sess, weights_queue, [0.0])
   3464       weights = weights_queue.dequeue()
   3465 
   3466       mean_iou, update_op = metrics.mean_iou(
   3467           labels, predictions, num_classes, weights=weights)
   3468 
   3469       variables.local_variables_initializer().run()
   3470       for _ in range(6):
   3471         sess.run(update_op)
   3472       desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0])
   3473       self.assertAlmostEqual(desired_output, mean_iou.eval())
   3474 
   3475   def testMultipleUpdatesWithMissingClass(self):
   3476     # Test the case where there are no predicions and labels for
   3477     # one class, and thus there is one row and one column with
   3478     # zero entries in the confusion matrix.
   3479     num_classes = 3
   3480     with self.test_session() as sess:
   3481       # Create the queue that populates the predictions.
   3482       # There is no prediction for class 2.
   3483       preds_queue = data_flow_ops.FIFOQueue(
   3484           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3485       _enqueue_vector(sess, preds_queue, [0])
   3486       _enqueue_vector(sess, preds_queue, [1])
   3487       _enqueue_vector(sess, preds_queue, [1])
   3488       _enqueue_vector(sess, preds_queue, [1])
   3489       _enqueue_vector(sess, preds_queue, [0])
   3490       predictions = preds_queue.dequeue()
   3491 
   3492       # Create the queue that populates the labels.
   3493       # There is label for class 2.
   3494       labels_queue = data_flow_ops.FIFOQueue(
   3495           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3496       _enqueue_vector(sess, labels_queue, [0])
   3497       _enqueue_vector(sess, labels_queue, [1])
   3498       _enqueue_vector(sess, labels_queue, [1])
   3499       _enqueue_vector(sess, labels_queue, [0])
   3500       _enqueue_vector(sess, labels_queue, [1])
   3501       labels = labels_queue.dequeue()
   3502 
   3503       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3504 
   3505       sess.run(variables.local_variables_initializer())
   3506       for _ in range(5):
   3507         sess.run(update_op)
   3508       desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0])
   3509       self.assertAlmostEqual(desired_output, miou.eval())
   3510 
   3511   def testUpdateOpEvalIsAccumulatedConfusionMatrix(self):
   3512     predictions = array_ops.concat(
   3513         [
   3514             constant_op.constant(
   3515                 0, shape=[5]), constant_op.constant(
   3516                     1, shape=[5])
   3517         ],
   3518         0)
   3519     labels = array_ops.concat(
   3520         [
   3521             constant_op.constant(
   3522                 0, shape=[3]), constant_op.constant(
   3523                     1, shape=[7])
   3524         ],
   3525         0)
   3526     num_classes = 2
   3527     with self.test_session() as sess:
   3528       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3529       sess.run(variables.local_variables_initializer())
   3530       confusion_matrix = update_op.eval()
   3531       self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix)
   3532       desired_miou = np.mean([3. / 5., 5. / 7.])
   3533       self.assertAlmostEqual(desired_miou, miou.eval())
   3534 
   3535   def testAllCorrect(self):
   3536     predictions = array_ops.zeros([40])
   3537     labels = array_ops.zeros([40])
   3538     num_classes = 1
   3539     with self.test_session() as sess:
   3540       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3541       sess.run(variables.local_variables_initializer())
   3542       self.assertEqual(40, update_op.eval()[0])
   3543       self.assertEqual(1.0, miou.eval())
   3544 
   3545   def testAllWrong(self):
   3546     predictions = array_ops.zeros([40])
   3547     labels = array_ops.ones([40])
   3548     num_classes = 2
   3549     with self.test_session() as sess:
   3550       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3551       sess.run(variables.local_variables_initializer())
   3552       self.assertAllEqual([[0, 0], [40, 0]], update_op.eval())
   3553       self.assertEqual(0., miou.eval())
   3554 
   3555   def testResultsWithSomeMissing(self):
   3556     predictions = array_ops.concat(
   3557         [
   3558             constant_op.constant(
   3559                 0, shape=[5]), constant_op.constant(
   3560                     1, shape=[5])
   3561         ],
   3562         0)
   3563     labels = array_ops.concat(
   3564         [
   3565             constant_op.constant(
   3566                 0, shape=[3]), constant_op.constant(
   3567                     1, shape=[7])
   3568         ],
   3569         0)
   3570     num_classes = 2
   3571     weights = array_ops.concat(
   3572         [
   3573             constant_op.constant(
   3574                 0, shape=[1]), constant_op.constant(
   3575                     1, shape=[8]), constant_op.constant(
   3576                         0, shape=[1])
   3577         ],
   3578         0)
   3579     with self.test_session() as sess:
   3580       miou, update_op = metrics.mean_iou(
   3581           labels, predictions, num_classes, weights=weights)
   3582       sess.run(variables.local_variables_initializer())
   3583       self.assertAllEqual([[2, 0], [2, 4]], update_op.eval())
   3584       desired_miou = np.mean([2. / 4., 4. / 6.])
   3585       self.assertAlmostEqual(desired_miou, miou.eval())
   3586 
   3587   def testMissingClassInLabels(self):
   3588     labels = constant_op.constant([
   3589         [[0, 0, 1, 1, 0, 0],
   3590          [1, 0, 0, 0, 0, 1]],
   3591         [[1, 1, 1, 1, 1, 1],
   3592          [0, 0, 0, 0, 0, 0]]])
   3593     predictions = constant_op.constant([
   3594         [[0, 0, 2, 1, 1, 0],
   3595          [0, 1, 2, 2, 0, 1]],
   3596         [[0, 0, 2, 1, 1, 1],
   3597          [1, 1, 2, 0, 0, 0]]])
   3598     num_classes = 3
   3599     with self.test_session() as sess:
   3600       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3601       sess.run(variables.local_variables_initializer())
   3602       self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval())
   3603       self.assertAlmostEqual(
   3604           1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)),
   3605           miou.eval())
   3606 
   3607   def testMissingClassOverallSmall(self):
   3608     labels = constant_op.constant([0])
   3609     predictions = constant_op.constant([0])
   3610     num_classes = 2
   3611     with self.test_session() as sess:
   3612       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3613       sess.run(variables.local_variables_initializer())
   3614       self.assertAllEqual([[1, 0], [0, 0]], update_op.eval())
   3615       self.assertAlmostEqual(1, miou.eval())
   3616 
   3617   def testMissingClassOverallLarge(self):
   3618     labels = constant_op.constant([
   3619         [[0, 0, 1, 1, 0, 0],
   3620          [1, 0, 0, 0, 0, 1]],
   3621         [[1, 1, 1, 1, 1, 1],
   3622          [0, 0, 0, 0, 0, 0]]])
   3623     predictions = constant_op.constant([
   3624         [[0, 0, 1, 1, 0, 0],
   3625          [1, 1, 0, 0, 1, 1]],
   3626         [[0, 0, 0, 1, 1, 1],
   3627          [1, 1, 1, 0, 0, 0]]])
   3628     num_classes = 3
   3629     with self.test_session() as sess:
   3630       miou, update_op = metrics.mean_iou(labels, predictions, num_classes)
   3631       sess.run(variables.local_variables_initializer())
   3632       self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval())
   3633       self.assertAlmostEqual(
   3634           1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval())
   3635 
   3636 
   3637 class MeanPerClassAccuracyTest(test.TestCase):
   3638 
   3639   def setUp(self):
   3640     np.random.seed(1)
   3641     ops.reset_default_graph()
   3642 
   3643   def testVars(self):
   3644     metrics.mean_per_class_accuracy(
   3645         predictions=array_ops.ones([10, 1]),
   3646         labels=array_ops.ones([10, 1]),
   3647         num_classes=2)
   3648     _assert_metric_variables(self, ('mean_accuracy/count:0',
   3649                                     'mean_accuracy/total:0'))
   3650 
   3651   def testMetricsCollections(self):
   3652     my_collection_name = '__metrics__'
   3653     mean_accuracy, _ = metrics.mean_per_class_accuracy(
   3654         predictions=array_ops.ones([10, 1]),
   3655         labels=array_ops.ones([10, 1]),
   3656         num_classes=2,
   3657         metrics_collections=[my_collection_name])
   3658     self.assertListEqual(
   3659         ops.get_collection(my_collection_name), [mean_accuracy])
   3660 
   3661   def testUpdatesCollection(self):
   3662     my_collection_name = '__updates__'
   3663     _, update_op = metrics.mean_per_class_accuracy(
   3664         predictions=array_ops.ones([10, 1]),
   3665         labels=array_ops.ones([10, 1]),
   3666         num_classes=2,
   3667         updates_collections=[my_collection_name])
   3668     self.assertListEqual(ops.get_collection(my_collection_name), [update_op])
   3669 
   3670   def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self):
   3671     predictions = array_ops.ones([10, 3])
   3672     labels = array_ops.ones([10, 4])
   3673     with self.assertRaises(ValueError):
   3674       metrics.mean_per_class_accuracy(labels, predictions, num_classes=2)
   3675 
   3676   def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self):
   3677     predictions = array_ops.ones([10])
   3678     labels = array_ops.ones([10])
   3679     weights = array_ops.zeros([9])
   3680     with self.assertRaises(ValueError):
   3681       metrics.mean_per_class_accuracy(
   3682           labels, predictions, num_classes=2, weights=weights)
   3683 
   3684   def testValueTensorIsIdempotent(self):
   3685     num_classes = 3
   3686     predictions = random_ops.random_uniform(
   3687         [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
   3688     labels = random_ops.random_uniform(
   3689         [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1)
   3690     mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3691         labels, predictions, num_classes=num_classes)
   3692 
   3693     with self.test_session() as sess:
   3694       sess.run(variables.local_variables_initializer())
   3695 
   3696       # Run several updates.
   3697       for _ in range(10):
   3698         sess.run(update_op)
   3699 
   3700       # Then verify idempotency.
   3701       initial_mean_accuracy = mean_accuracy.eval()
   3702       for _ in range(10):
   3703         self.assertEqual(initial_mean_accuracy, mean_accuracy.eval())
   3704 
   3705     num_classes = 3
   3706     with self.test_session() as sess:
   3707       # Create the queue that populates the predictions.
   3708       preds_queue = data_flow_ops.FIFOQueue(
   3709           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3710       _enqueue_vector(sess, preds_queue, [0])
   3711       _enqueue_vector(sess, preds_queue, [1])
   3712       _enqueue_vector(sess, preds_queue, [2])
   3713       _enqueue_vector(sess, preds_queue, [1])
   3714       _enqueue_vector(sess, preds_queue, [0])
   3715       predictions = preds_queue.dequeue()
   3716 
   3717       # Create the queue that populates the labels.
   3718       labels_queue = data_flow_ops.FIFOQueue(
   3719           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3720       _enqueue_vector(sess, labels_queue, [0])
   3721       _enqueue_vector(sess, labels_queue, [1])
   3722       _enqueue_vector(sess, labels_queue, [1])
   3723       _enqueue_vector(sess, labels_queue, [2])
   3724       _enqueue_vector(sess, labels_queue, [1])
   3725       labels = labels_queue.dequeue()
   3726 
   3727       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3728           labels, predictions, num_classes)
   3729 
   3730       sess.run(variables.local_variables_initializer())
   3731       for _ in range(5):
   3732         sess.run(update_op)
   3733       desired_output = np.mean([1.0, 1.0 / 3.0, 0.0])
   3734       self.assertAlmostEqual(desired_output, mean_accuracy.eval())
   3735 
   3736   def testMultipleUpdatesWithWeights(self):
   3737     num_classes = 2
   3738     with self.test_session() as sess:
   3739       # Create the queue that populates the predictions.
   3740       preds_queue = data_flow_ops.FIFOQueue(
   3741           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3742       _enqueue_vector(sess, preds_queue, [0])
   3743       _enqueue_vector(sess, preds_queue, [1])
   3744       _enqueue_vector(sess, preds_queue, [0])
   3745       _enqueue_vector(sess, preds_queue, [1])
   3746       _enqueue_vector(sess, preds_queue, [0])
   3747       _enqueue_vector(sess, preds_queue, [1])
   3748       predictions = preds_queue.dequeue()
   3749 
   3750       # Create the queue that populates the labels.
   3751       labels_queue = data_flow_ops.FIFOQueue(
   3752           6, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3753       _enqueue_vector(sess, labels_queue, [0])
   3754       _enqueue_vector(sess, labels_queue, [1])
   3755       _enqueue_vector(sess, labels_queue, [1])
   3756       _enqueue_vector(sess, labels_queue, [0])
   3757       _enqueue_vector(sess, labels_queue, [0])
   3758       _enqueue_vector(sess, labels_queue, [1])
   3759       labels = labels_queue.dequeue()
   3760 
   3761       # Create the queue that populates the weights.
   3762       weights_queue = data_flow_ops.FIFOQueue(
   3763           6, dtypes=dtypes_lib.float32, shapes=(1, 1))
   3764       _enqueue_vector(sess, weights_queue, [1.0])
   3765       _enqueue_vector(sess, weights_queue, [0.5])
   3766       _enqueue_vector(sess, weights_queue, [1.0])
   3767       _enqueue_vector(sess, weights_queue, [0.0])
   3768       _enqueue_vector(sess, weights_queue, [1.0])
   3769       _enqueue_vector(sess, weights_queue, [0.0])
   3770       weights = weights_queue.dequeue()
   3771 
   3772       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3773           labels, predictions, num_classes, weights=weights)
   3774 
   3775       variables.local_variables_initializer().run()
   3776       for _ in range(6):
   3777         sess.run(update_op)
   3778       desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5])
   3779       self.assertAlmostEqual(desired_output, mean_accuracy.eval())
   3780 
   3781   def testMultipleUpdatesWithMissingClass(self):
   3782     # Test the case where there are no predicions and labels for
   3783     # one class, and thus there is one row and one column with
   3784     # zero entries in the confusion matrix.
   3785     num_classes = 3
   3786     with self.test_session() as sess:
   3787       # Create the queue that populates the predictions.
   3788       # There is no prediction for class 2.
   3789       preds_queue = data_flow_ops.FIFOQueue(
   3790           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3791       _enqueue_vector(sess, preds_queue, [0])
   3792       _enqueue_vector(sess, preds_queue, [1])
   3793       _enqueue_vector(sess, preds_queue, [1])
   3794       _enqueue_vector(sess, preds_queue, [1])
   3795       _enqueue_vector(sess, preds_queue, [0])
   3796       predictions = preds_queue.dequeue()
   3797 
   3798       # Create the queue that populates the labels.
   3799       # There is label for class 2.
   3800       labels_queue = data_flow_ops.FIFOQueue(
   3801           5, dtypes=dtypes_lib.int32, shapes=(1, 1))
   3802       _enqueue_vector(sess, labels_queue, [0])
   3803       _enqueue_vector(sess, labels_queue, [1])
   3804       _enqueue_vector(sess, labels_queue, [1])
   3805       _enqueue_vector(sess, labels_queue, [0])
   3806       _enqueue_vector(sess, labels_queue, [1])
   3807       labels = labels_queue.dequeue()
   3808 
   3809       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3810           labels, predictions, num_classes)
   3811 
   3812       sess.run(variables.local_variables_initializer())
   3813       for _ in range(5):
   3814         sess.run(update_op)
   3815       desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.])
   3816       self.assertAlmostEqual(desired_output, mean_accuracy.eval())
   3817 
   3818   def testAllCorrect(self):
   3819     predictions = array_ops.zeros([40])
   3820     labels = array_ops.zeros([40])
   3821     num_classes = 1
   3822     with self.test_session() as sess:
   3823       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3824           labels, predictions, num_classes)
   3825       sess.run(variables.local_variables_initializer())
   3826       self.assertEqual(1.0, update_op.eval()[0])
   3827       self.assertEqual(1.0, mean_accuracy.eval())
   3828 
   3829   def testAllWrong(self):
   3830     predictions = array_ops.zeros([40])
   3831     labels = array_ops.ones([40])
   3832     num_classes = 2
   3833     with self.test_session() as sess:
   3834       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3835           labels, predictions, num_classes)
   3836       sess.run(variables.local_variables_initializer())
   3837       self.assertAllEqual([0.0, 0.0], update_op.eval())
   3838       self.assertEqual(0., mean_accuracy.eval())
   3839 
   3840   def testResultsWithSomeMissing(self):
   3841     predictions = array_ops.concat([
   3842         constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5])
   3843     ], 0)
   3844     labels = array_ops.concat([
   3845         constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7])
   3846     ], 0)
   3847     num_classes = 2
   3848     weights = array_ops.concat([
   3849         constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]),
   3850         constant_op.constant(0, shape=[1])
   3851     ], 0)
   3852     with self.test_session() as sess:
   3853       mean_accuracy, update_op = metrics.mean_per_class_accuracy(
   3854           labels, predictions, num_classes, weights=weights)
   3855       sess.run(variables.local_variables_initializer())
   3856       desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32)
   3857       self.assertAllEqual(desired_accuracy, update_op.eval())
   3858       desired_mean_accuracy = np.mean(desired_accuracy)
   3859       self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval())
   3860 
   3861 
   3862 class FalseNegativesTest(test.TestCase):
   3863 
   3864   def setUp(self):
   3865     np.random.seed(1)
   3866     ops.reset_default_graph()
   3867 
   3868   def testVars(self):
   3869     metrics.false_negatives(
   3870         labels=(0, 1, 0, 1),
   3871         predictions=(0, 0, 1, 1))
   3872     _assert_metric_variables(self, ('false_negatives/count:0',))
   3873 
   3874   def testUnweighted(self):
   3875     labels = constant_op.constant(((0, 1, 0, 1, 0),
   3876                                    (0, 0, 1, 1, 1),
   3877                                    (1, 1, 1, 1, 0),
   3878                                    (0, 0, 0, 0, 1)))
   3879     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   3880                                         (1, 1, 1, 1, 1),
   3881                                         (0, 1, 0, 1, 0),
   3882                                         (1, 1, 1, 1, 1)))
   3883     tn, tn_update_op = metrics.false_negatives(
   3884         labels=labels, predictions=predictions)
   3885 
   3886     with self.test_session() as sess:
   3887       sess.run(variables.local_variables_initializer())
   3888       self.assertAllClose(0., tn.eval())
   3889       self.assertAllClose(3., tn_update_op.eval())
   3890       self.assertAllClose(3., tn.eval())
   3891 
   3892   def testWeighted(self):
   3893     labels = constant_op.constant(((0, 1, 0, 1, 0),
   3894                                    (0, 0, 1, 1, 1),
   3895                                    (1, 1, 1, 1, 0),
   3896                                    (0, 0, 0, 0, 1)))
   3897     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   3898                                         (1, 1, 1, 1, 1),
   3899                                         (0, 1, 0, 1, 0),
   3900                                         (1, 1, 1, 1, 1)))
   3901     weights = constant_op.constant((1., 1.5, 2., 2.5))
   3902     tn, tn_update_op = metrics.false_negatives(
   3903         labels=labels, predictions=predictions, weights=weights)
   3904 
   3905     with self.test_session() as sess:
   3906       sess.run(variables.local_variables_initializer())
   3907       self.assertAllClose(0., tn.eval())
   3908       self.assertAllClose(5., tn_update_op.eval())
   3909       self.assertAllClose(5., tn.eval())
   3910 
   3911 
   3912 class FalseNegativesAtThresholdsTest(test.TestCase):
   3913 
   3914   def setUp(self):
   3915     np.random.seed(1)
   3916     ops.reset_default_graph()
   3917 
   3918   def testVars(self):
   3919     metrics.false_negatives_at_thresholds(
   3920         predictions=array_ops.ones((10, 1)),
   3921         labels=array_ops.ones((10, 1)),
   3922         thresholds=[0.15, 0.5, 0.85])
   3923     _assert_metric_variables(self, ('false_negatives/false_negatives:0',))
   3924 
   3925   def testUnweighted(self):
   3926     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   3927                                         (0.2, 0.9, 0.7, 0.6),
   3928                                         (0.1, 0.2, 0.4, 0.3)))
   3929     labels = constant_op.constant(((0, 1, 1, 0),
   3930                                    (1, 0, 0, 0),
   3931                                    (0, 0, 0, 0)))
   3932     fn, fn_update_op = metrics.false_negatives_at_thresholds(
   3933         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
   3934 
   3935     with self.test_session() as sess:
   3936       sess.run(variables.local_variables_initializer())
   3937       self.assertAllEqual((0, 0, 0), fn.eval())
   3938       self.assertAllEqual((0, 2, 3), fn_update_op.eval())
   3939       self.assertAllEqual((0, 2, 3), fn.eval())
   3940 
   3941   def testWeighted(self):
   3942     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   3943                                         (0.2, 0.9, 0.7, 0.6),
   3944                                         (0.1, 0.2, 0.4, 0.3)))
   3945     labels = constant_op.constant(((0, 1, 1, 0),
   3946                                    (1, 0, 0, 0),
   3947                                    (0, 0, 0, 0)))
   3948     fn, fn_update_op = metrics.false_negatives_at_thresholds(
   3949         predictions=predictions,
   3950         labels=labels,
   3951         weights=((3.0,), (5.0,), (7.0,)),
   3952         thresholds=[0.15, 0.5, 0.85])
   3953 
   3954     with self.test_session() as sess:
   3955       sess.run(variables.local_variables_initializer())
   3956       self.assertAllEqual((0.0, 0.0, 0.0), fn.eval())
   3957       self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval())
   3958       self.assertAllEqual((0.0, 8.0, 11.0), fn.eval())
   3959 
   3960 
   3961 class FalsePositivesTest(test.TestCase):
   3962 
   3963   def setUp(self):
   3964     np.random.seed(1)
   3965     ops.reset_default_graph()
   3966 
   3967   def testVars(self):
   3968     metrics.false_positives(
   3969         labels=(0, 1, 0, 1),
   3970         predictions=(0, 0, 1, 1))
   3971     _assert_metric_variables(self, ('false_positives/count:0',))
   3972 
   3973   def testUnweighted(self):
   3974     labels = constant_op.constant(((0, 1, 0, 1, 0),
   3975                                    (0, 0, 1, 1, 1),
   3976                                    (1, 1, 1, 1, 0),
   3977                                    (0, 0, 0, 0, 1)))
   3978     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   3979                                         (1, 1, 1, 1, 1),
   3980                                         (0, 1, 0, 1, 0),
   3981                                         (1, 1, 1, 1, 1)))
   3982     tn, tn_update_op = metrics.false_positives(
   3983         labels=labels, predictions=predictions)
   3984 
   3985     with self.test_session() as sess:
   3986       sess.run(variables.local_variables_initializer())
   3987       self.assertAllClose(0., tn.eval())
   3988       self.assertAllClose(7., tn_update_op.eval())
   3989       self.assertAllClose(7., tn.eval())
   3990 
   3991   def testWeighted(self):
   3992     labels = constant_op.constant(((0, 1, 0, 1, 0),
   3993                                    (0, 0, 1, 1, 1),
   3994                                    (1, 1, 1, 1, 0),
   3995                                    (0, 0, 0, 0, 1)))
   3996     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   3997                                         (1, 1, 1, 1, 1),
   3998                                         (0, 1, 0, 1, 0),
   3999                                         (1, 1, 1, 1, 1)))
   4000     weights = constant_op.constant((1., 1.5, 2., 2.5))
   4001     tn, tn_update_op = metrics.false_positives(
   4002         labels=labels, predictions=predictions, weights=weights)
   4003 
   4004     with self.test_session() as sess:
   4005       sess.run(variables.local_variables_initializer())
   4006       self.assertAllClose(0., tn.eval())
   4007       self.assertAllClose(14., tn_update_op.eval())
   4008       self.assertAllClose(14., tn.eval())
   4009 
   4010 
   4011 class FalsePositivesAtThresholdsTest(test.TestCase):
   4012 
   4013   def setUp(self):
   4014     np.random.seed(1)
   4015     ops.reset_default_graph()
   4016 
   4017   def testVars(self):
   4018     metrics.false_positives_at_thresholds(
   4019         predictions=array_ops.ones((10, 1)),
   4020         labels=array_ops.ones((10, 1)),
   4021         thresholds=[0.15, 0.5, 0.85])
   4022     _assert_metric_variables(self, ('false_positives/false_positives:0',))
   4023 
   4024   def testUnweighted(self):
   4025     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4026                                         (0.2, 0.9, 0.7, 0.6),
   4027                                         (0.1, 0.2, 0.4, 0.3)))
   4028     labels = constant_op.constant(((0, 1, 1, 0),
   4029                                    (1, 0, 0, 0),
   4030                                    (0, 0, 0, 0)))
   4031     fp, fp_update_op = metrics.false_positives_at_thresholds(
   4032         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
   4033 
   4034     with self.test_session() as sess:
   4035       sess.run(variables.local_variables_initializer())
   4036       self.assertAllEqual((0, 0, 0), fp.eval())
   4037       self.assertAllEqual((7, 4, 2), fp_update_op.eval())
   4038       self.assertAllEqual((7, 4, 2), fp.eval())
   4039 
   4040   def testWeighted(self):
   4041     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4042                                         (0.2, 0.9, 0.7, 0.6),
   4043                                         (0.1, 0.2, 0.4, 0.3)))
   4044     labels = constant_op.constant(((0, 1, 1, 0),
   4045                                    (1, 0, 0, 0),
   4046                                    (0, 0, 0, 0)))
   4047     fp, fp_update_op = metrics.false_positives_at_thresholds(
   4048         predictions=predictions,
   4049         labels=labels,
   4050         weights=((1.0, 2.0, 3.0, 5.0),
   4051                  (7.0, 11.0, 13.0, 17.0),
   4052                  (19.0, 23.0, 29.0, 31.0)),
   4053         thresholds=[0.15, 0.5, 0.85])
   4054 
   4055     with self.test_session() as sess:
   4056       sess.run(variables.local_variables_initializer())
   4057       self.assertAllEqual((0.0, 0.0, 0.0), fp.eval())
   4058       self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval())
   4059       self.assertAllEqual((125.0, 42.0, 12.0), fp.eval())
   4060 
   4061 
   4062 class TrueNegativesTest(test.TestCase):
   4063 
   4064   def setUp(self):
   4065     np.random.seed(1)
   4066     ops.reset_default_graph()
   4067 
   4068   def testVars(self):
   4069     metrics.true_negatives(
   4070         labels=(0, 1, 0, 1),
   4071         predictions=(0, 0, 1, 1))
   4072     _assert_metric_variables(self, ('true_negatives/count:0',))
   4073 
   4074   def testUnweighted(self):
   4075     labels = constant_op.constant(((0, 1, 0, 1, 0),
   4076                                    (0, 0, 1, 1, 1),
   4077                                    (1, 1, 1, 1, 0),
   4078                                    (0, 0, 0, 0, 1)))
   4079     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   4080                                         (1, 1, 1, 1, 1),
   4081                                         (0, 1, 0, 1, 0),
   4082                                         (1, 1, 1, 1, 1)))
   4083     tn, tn_update_op = metrics.true_negatives(
   4084         labels=labels, predictions=predictions)
   4085 
   4086     with self.test_session() as sess:
   4087       sess.run(variables.local_variables_initializer())
   4088       self.assertAllClose(0., tn.eval())
   4089       self.assertAllClose(3., tn_update_op.eval())
   4090       self.assertAllClose(3., tn.eval())
   4091 
   4092   def testWeighted(self):
   4093     labels = constant_op.constant(((0, 1, 0, 1, 0),
   4094                                    (0, 0, 1, 1, 1),
   4095                                    (1, 1, 1, 1, 0),
   4096                                    (0, 0, 0, 0, 1)))
   4097     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   4098                                         (1, 1, 1, 1, 1),
   4099                                         (0, 1, 0, 1, 0),
   4100                                         (1, 1, 1, 1, 1)))
   4101     weights = constant_op.constant((1., 1.5, 2., 2.5))
   4102     tn, tn_update_op = metrics.true_negatives(
   4103         labels=labels, predictions=predictions, weights=weights)
   4104 
   4105     with self.test_session() as sess:
   4106       sess.run(variables.local_variables_initializer())
   4107       self.assertAllClose(0., tn.eval())
   4108       self.assertAllClose(4., tn_update_op.eval())
   4109       self.assertAllClose(4., tn.eval())
   4110 
   4111 
   4112 class TrueNegativesAtThresholdsTest(test.TestCase):
   4113 
   4114   def setUp(self):
   4115     np.random.seed(1)
   4116     ops.reset_default_graph()
   4117 
   4118   def testVars(self):
   4119     metrics.true_negatives_at_thresholds(
   4120         predictions=array_ops.ones((10, 1)),
   4121         labels=array_ops.ones((10, 1)),
   4122         thresholds=[0.15, 0.5, 0.85])
   4123     _assert_metric_variables(self, ('true_negatives/true_negatives:0',))
   4124 
   4125   def testUnweighted(self):
   4126     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4127                                         (0.2, 0.9, 0.7, 0.6),
   4128                                         (0.1, 0.2, 0.4, 0.3)))
   4129     labels = constant_op.constant(((0, 1, 1, 0),
   4130                                    (1, 0, 0, 0),
   4131                                    (0, 0, 0, 0)))
   4132     tn, tn_update_op = metrics.true_negatives_at_thresholds(
   4133         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
   4134 
   4135     with self.test_session() as sess:
   4136       sess.run(variables.local_variables_initializer())
   4137       self.assertAllEqual((0, 0, 0), tn.eval())
   4138       self.assertAllEqual((2, 5, 7), tn_update_op.eval())
   4139       self.assertAllEqual((2, 5, 7), tn.eval())
   4140 
   4141   def testWeighted(self):
   4142     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4143                                         (0.2, 0.9, 0.7, 0.6),
   4144                                         (0.1, 0.2, 0.4, 0.3)))
   4145     labels = constant_op.constant(((0, 1, 1, 0),
   4146                                    (1, 0, 0, 0),
   4147                                    (0, 0, 0, 0)))
   4148     tn, tn_update_op = metrics.true_negatives_at_thresholds(
   4149         predictions=predictions,
   4150         labels=labels,
   4151         weights=((0.0, 2.0, 3.0, 5.0),),
   4152         thresholds=[0.15, 0.5, 0.85])
   4153 
   4154     with self.test_session() as sess:
   4155       sess.run(variables.local_variables_initializer())
   4156       self.assertAllEqual((0.0, 0.0, 0.0), tn.eval())
   4157       self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval())
   4158       self.assertAllEqual((5.0, 15.0, 23.0), tn.eval())
   4159 
   4160 
   4161 class TruePositivesTest(test.TestCase):
   4162 
   4163   def setUp(self):
   4164     np.random.seed(1)
   4165     ops.reset_default_graph()
   4166 
   4167   def testVars(self):
   4168     metrics.true_positives(
   4169         labels=(0, 1, 0, 1),
   4170         predictions=(0, 0, 1, 1))
   4171     _assert_metric_variables(self, ('true_positives/count:0',))
   4172 
   4173   def testUnweighted(self):
   4174     labels = constant_op.constant(((0, 1, 0, 1, 0),
   4175                                    (0, 0, 1, 1, 1),
   4176                                    (1, 1, 1, 1, 0),
   4177                                    (0, 0, 0, 0, 1)))
   4178     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   4179                                         (1, 1, 1, 1, 1),
   4180                                         (0, 1, 0, 1, 0),
   4181                                         (1, 1, 1, 1, 1)))
   4182     tn, tn_update_op = metrics.true_positives(
   4183         labels=labels, predictions=predictions)
   4184 
   4185     with self.test_session() as sess:
   4186       sess.run(variables.local_variables_initializer())
   4187       self.assertAllClose(0., tn.eval())
   4188       self.assertAllClose(7., tn_update_op.eval())
   4189       self.assertAllClose(7., tn.eval())
   4190 
   4191   def testWeighted(self):
   4192     labels = constant_op.constant(((0, 1, 0, 1, 0),
   4193                                    (0, 0, 1, 1, 1),
   4194                                    (1, 1, 1, 1, 0),
   4195                                    (0, 0, 0, 0, 1)))
   4196     predictions = constant_op.constant(((0, 0, 1, 1, 0),
   4197                                         (1, 1, 1, 1, 1),
   4198                                         (0, 1, 0, 1, 0),
   4199                                         (1, 1, 1, 1, 1)))
   4200     weights = constant_op.constant((1., 1.5, 2., 2.5))
   4201     tn, tn_update_op = metrics.true_positives(
   4202         labels=labels, predictions=predictions, weights=weights)
   4203 
   4204     with self.test_session() as sess:
   4205       sess.run(variables.local_variables_initializer())
   4206       self.assertAllClose(0., tn.eval())
   4207       self.assertAllClose(12., tn_update_op.eval())
   4208       self.assertAllClose(12., tn.eval())
   4209 
   4210 
   4211 class TruePositivesAtThresholdsTest(test.TestCase):
   4212 
   4213   def setUp(self):
   4214     np.random.seed(1)
   4215     ops.reset_default_graph()
   4216 
   4217   def testVars(self):
   4218     metrics.true_positives_at_thresholds(
   4219         predictions=array_ops.ones((10, 1)),
   4220         labels=array_ops.ones((10, 1)),
   4221         thresholds=[0.15, 0.5, 0.85])
   4222     _assert_metric_variables(self, ('true_positives/true_positives:0',))
   4223 
   4224   def testUnweighted(self):
   4225     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4226                                         (0.2, 0.9, 0.7, 0.6),
   4227                                         (0.1, 0.2, 0.4, 0.3)))
   4228     labels = constant_op.constant(((0, 1, 1, 0),
   4229                                    (1, 0, 0, 0),
   4230                                    (0, 0, 0, 0)))
   4231     tp, tp_update_op = metrics.true_positives_at_thresholds(
   4232         predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85])
   4233 
   4234     with self.test_session() as sess:
   4235       sess.run(variables.local_variables_initializer())
   4236       self.assertAllEqual((0, 0, 0), tp.eval())
   4237       self.assertAllEqual((3, 1, 0), tp_update_op.eval())
   4238       self.assertAllEqual((3, 1, 0), tp.eval())
   4239 
   4240   def testWeighted(self):
   4241     predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1),
   4242                                         (0.2, 0.9, 0.7, 0.6),
   4243                                         (0.1, 0.2, 0.4, 0.3)))
   4244     labels = constant_op.constant(((0, 1, 1, 0),
   4245                                    (1, 0, 0, 0),
   4246                                    (0, 0, 0, 0)))
   4247     tp, tp_update_op = metrics.true_positives_at_thresholds(
   4248         predictions=predictions, labels=labels, weights=37.0,
   4249         thresholds=[0.15, 0.5, 0.85])
   4250 
   4251     with self.test_session() as sess:
   4252       sess.run(variables.local_variables_initializer())
   4253       self.assertAllEqual((0.0, 0.0, 0.0), tp.eval())
   4254       self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval())
   4255       self.assertAllEqual((111.0, 37.0, 0.0), tp.eval())
   4256 
   4257 
   4258 if __name__ == '__main__':
   4259   test.main()
   4260