1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for metrics.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 import math 23 24 import numpy as np 25 from six.moves import xrange # pylint: disable=redefined-builtin 26 27 from tensorflow.python.framework import constant_op 28 from tensorflow.python.framework import dtypes as dtypes_lib 29 from tensorflow.python.framework import errors_impl 30 from tensorflow.python.framework import ops 31 from tensorflow.python.framework import sparse_tensor 32 from tensorflow.python.ops import array_ops 33 from tensorflow.python.ops import data_flow_ops 34 from tensorflow.python.ops import math_ops 35 from tensorflow.python.ops import metrics 36 from tensorflow.python.ops import random_ops 37 from tensorflow.python.ops import variables 38 import tensorflow.python.ops.data_flow_grad # pylint: disable=unused-import 39 import tensorflow.python.ops.nn_grad # pylint: disable=unused-import 40 from tensorflow.python.platform import test 41 42 NAN = float('nan') 43 44 45 def _enqueue_vector(sess, queue, values, shape=None): 46 if not shape: 47 shape = (1, len(values)) 48 dtype = queue.dtypes[0] 49 sess.run( 50 queue.enqueue(constant_op.constant( 51 values, dtype=dtype, 
shape=shape))) 52 53 54 def _binary_2d_label_to_2d_sparse_value(labels): 55 """Convert dense 2D binary indicator to sparse ID. 56 57 Only 1 values in `labels` are included in result. 58 59 Args: 60 labels: Dense 2D binary indicator, shape [batch_size, num_classes]. 61 62 Returns: 63 `SparseTensorValue` of shape [batch_size, num_classes], where num_classes 64 is the number of `1` values in each row of `labels`. Values are indices 65 of `1` values along the last dimension of `labels`. 66 """ 67 indices = [] 68 values = [] 69 batch = 0 70 for row in labels: 71 label = 0 72 xi = 0 73 for x in row: 74 if x == 1: 75 indices.append([batch, xi]) 76 values.append(label) 77 xi += 1 78 else: 79 assert x == 0 80 label += 1 81 batch += 1 82 shape = [len(labels), len(labels[0])] 83 return sparse_tensor.SparseTensorValue( 84 np.array(indices, np.int64), 85 np.array(values, np.int64), np.array(shape, np.int64)) 86 87 88 def _binary_2d_label_to_1d_sparse_value(labels): 89 """Convert dense 2D binary indicator to sparse ID. 90 91 Only 1 values in `labels` are included in result. 92 93 Args: 94 labels: Dense 2D binary indicator, shape [batch_size, num_classes]. Each 95 row must contain exactly 1 `1` value. 96 97 Returns: 98 `SparseTensorValue` of shape [batch_size]. Values are indices of `1` values 99 along the last dimension of `labels`. 100 101 Raises: 102 ValueError: if there is not exactly 1 `1` value per row of `labels`. 103 """ 104 indices = [] 105 values = [] 106 batch = 0 107 for row in labels: 108 label = 0 109 xi = 0 110 for x in row: 111 if x == 1: 112 indices.append([batch]) 113 values.append(label) 114 xi += 1 115 else: 116 assert x == 0 117 label += 1 118 batch += 1 119 if indices != [[i] for i in range(len(labels))]: 120 raise ValueError('Expected 1 label/example, got %s.' 
% indices) 121 shape = [len(labels)] 122 return sparse_tensor.SparseTensorValue( 123 np.array(indices, np.int64), 124 np.array(values, np.int64), np.array(shape, np.int64)) 125 126 127 def _binary_3d_label_to_sparse_value(labels): 128 """Convert dense 3D binary indicator tensor to sparse tensor. 129 130 Only 1 values in `labels` are included in result. 131 132 Args: 133 labels: Dense 2D binary indicator tensor. 134 135 Returns: 136 `SparseTensorValue` whose values are indices along the last dimension of 137 `labels`. 138 """ 139 indices = [] 140 values = [] 141 for d0, labels_d0 in enumerate(labels): 142 for d1, labels_d1 in enumerate(labels_d0): 143 d2 = 0 144 for class_id, label in enumerate(labels_d1): 145 if label == 1: 146 values.append(class_id) 147 indices.append([d0, d1, d2]) 148 d2 += 1 149 else: 150 assert label == 0 151 shape = [len(labels), len(labels[0]), len(labels[0][0])] 152 return sparse_tensor.SparseTensorValue( 153 np.array(indices, np.int64), 154 np.array(values, np.int64), np.array(shape, np.int64)) 155 156 157 def _assert_nan(test_case, actual): 158 test_case.assertTrue(math.isnan(actual), 'Expected NAN, got %s.' 
% actual) 159 160 161 def _assert_metric_variables(test_case, expected): 162 test_case.assertEquals( 163 set(expected), set(v.name for v in variables.local_variables())) 164 test_case.assertEquals( 165 set(expected), 166 set(v.name for v in ops.get_collection(ops.GraphKeys.METRIC_VARIABLES))) 167 168 169 def _test_values(shape): 170 return np.reshape(np.cumsum(np.ones(shape)), newshape=shape) 171 172 173 class MeanTest(test.TestCase): 174 175 def setUp(self): 176 ops.reset_default_graph() 177 178 def testVars(self): 179 metrics.mean(array_ops.ones([4, 3])) 180 _assert_metric_variables(self, ('mean/count:0', 'mean/total:0')) 181 182 def testMetricsCollection(self): 183 my_collection_name = '__metrics__' 184 mean, _ = metrics.mean( 185 array_ops.ones([4, 3]), metrics_collections=[my_collection_name]) 186 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 187 188 def testUpdatesCollection(self): 189 my_collection_name = '__updates__' 190 _, update_op = metrics.mean( 191 array_ops.ones([4, 3]), updates_collections=[my_collection_name]) 192 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 193 194 def testBasic(self): 195 with self.test_session() as sess: 196 values_queue = data_flow_ops.FIFOQueue( 197 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 198 _enqueue_vector(sess, values_queue, [0, 1]) 199 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 200 _enqueue_vector(sess, values_queue, [6.5, 0]) 201 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 202 values = values_queue.dequeue() 203 204 mean, update_op = metrics.mean(values) 205 206 sess.run(variables.local_variables_initializer()) 207 for _ in range(4): 208 sess.run(update_op) 209 self.assertAlmostEqual(1.65, sess.run(mean), 5) 210 211 def testUpdateOpsReturnsCurrentValue(self): 212 with self.test_session() as sess: 213 values_queue = data_flow_ops.FIFOQueue( 214 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 215 _enqueue_vector(sess, values_queue, [0, 1]) 216 
_enqueue_vector(sess, values_queue, [-4.2, 9.1]) 217 _enqueue_vector(sess, values_queue, [6.5, 0]) 218 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 219 values = values_queue.dequeue() 220 221 mean, update_op = metrics.mean(values) 222 223 sess.run(variables.local_variables_initializer()) 224 225 self.assertAlmostEqual(0.5, sess.run(update_op), 5) 226 self.assertAlmostEqual(1.475, sess.run(update_op), 5) 227 self.assertAlmostEqual(12.4 / 6.0, sess.run(update_op), 5) 228 self.assertAlmostEqual(1.65, sess.run(update_op), 5) 229 230 self.assertAlmostEqual(1.65, sess.run(mean), 5) 231 232 def testUnweighted(self): 233 values = _test_values((3, 2, 4, 1)) 234 mean_results = ( 235 metrics.mean(values), 236 metrics.mean(values, weights=1.0), 237 metrics.mean(values, weights=np.ones((1, 1, 1))), 238 metrics.mean(values, weights=np.ones((1, 1, 1, 1))), 239 metrics.mean(values, weights=np.ones((1, 1, 1, 1, 1))), 240 metrics.mean(values, weights=np.ones((1, 1, 4))), 241 metrics.mean(values, weights=np.ones((1, 1, 4, 1))), 242 metrics.mean(values, weights=np.ones((1, 2, 1))), 243 metrics.mean(values, weights=np.ones((1, 2, 1, 1))), 244 metrics.mean(values, weights=np.ones((1, 2, 4))), 245 metrics.mean(values, weights=np.ones((1, 2, 4, 1))), 246 metrics.mean(values, weights=np.ones((3, 1, 1))), 247 metrics.mean(values, weights=np.ones((3, 1, 1, 1))), 248 metrics.mean(values, weights=np.ones((3, 1, 4))), 249 metrics.mean(values, weights=np.ones((3, 1, 4, 1))), 250 metrics.mean(values, weights=np.ones((3, 2, 1))), 251 metrics.mean(values, weights=np.ones((3, 2, 1, 1))), 252 metrics.mean(values, weights=np.ones((3, 2, 4))), 253 metrics.mean(values, weights=np.ones((3, 2, 4, 1))), 254 metrics.mean(values, weights=np.ones((3, 2, 4, 1, 1))),) 255 expected = np.mean(values) 256 with self.test_session(): 257 variables.local_variables_initializer().run() 258 for mean_result in mean_results: 259 mean, update_op = mean_result 260 self.assertAlmostEqual(expected, update_op.eval()) 261 
self.assertAlmostEqual(expected, mean.eval()) 262 263 def _test_3d_weighted(self, values, weights): 264 expected = ( 265 np.sum(np.multiply(weights, values)) / 266 np.sum(np.multiply(weights, np.ones_like(values))) 267 ) 268 mean, update_op = metrics.mean(values, weights=weights) 269 with self.test_session(): 270 variables.local_variables_initializer().run() 271 self.assertAlmostEqual(expected, update_op.eval(), places=5) 272 self.assertAlmostEqual(expected, mean.eval(), places=5) 273 274 def test1x1x1Weighted(self): 275 self._test_3d_weighted( 276 _test_values((3, 2, 4)), 277 weights=np.asarray((5,)).reshape((1, 1, 1))) 278 279 def test1x1xNWeighted(self): 280 self._test_3d_weighted( 281 _test_values((3, 2, 4)), 282 weights=np.asarray((5, 7, 11, 3)).reshape((1, 1, 4))) 283 284 def test1xNx1Weighted(self): 285 self._test_3d_weighted( 286 _test_values((3, 2, 4)), 287 weights=np.asarray((5, 11)).reshape((1, 2, 1))) 288 289 def test1xNxNWeighted(self): 290 self._test_3d_weighted( 291 _test_values((3, 2, 4)), 292 weights=np.asarray((5, 7, 11, 3, 2, 13, 7, 5)).reshape((1, 2, 4))) 293 294 def testNx1x1Weighted(self): 295 self._test_3d_weighted( 296 _test_values((3, 2, 4)), 297 weights=np.asarray((5, 7, 11)).reshape((3, 1, 1))) 298 299 def testNx1xNWeighted(self): 300 self._test_3d_weighted( 301 _test_values((3, 2, 4)), 302 weights=np.asarray(( 303 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3)).reshape((3, 1, 4))) 304 305 def testNxNxNWeighted(self): 306 self._test_3d_weighted( 307 _test_values((3, 2, 4)), 308 weights=np.asarray(( 309 5, 7, 11, 3, 2, 12, 7, 5, 2, 17, 11, 3, 310 2, 17, 11, 3, 5, 7, 11, 3, 2, 12, 7, 5)).reshape((3, 2, 4))) 311 312 def testInvalidWeights(self): 313 values_placeholder = array_ops.placeholder(dtype=dtypes_lib.float32) 314 values = _test_values((3, 2, 4, 1)) 315 invalid_weights = ( 316 (1,), 317 (1, 1), 318 (3, 2), 319 (2, 4, 1), 320 (4, 2, 4, 1), 321 (3, 3, 4, 1), 322 (3, 2, 5, 1), 323 (3, 2, 4, 2), 324 (1, 1, 1, 1, 1)) 325 expected_error_msg = 
'weights can not be broadcast to values' 326 for invalid_weight in invalid_weights: 327 # Static shapes. 328 with self.assertRaisesRegexp(ValueError, expected_error_msg): 329 metrics.mean(values, invalid_weight) 330 331 # Dynamic shapes. 332 with self.assertRaisesRegexp(errors_impl.OpError, expected_error_msg): 333 with self.test_session(): 334 _, update_op = metrics.mean(values_placeholder, invalid_weight) 335 variables.local_variables_initializer().run() 336 update_op.eval(feed_dict={values_placeholder: values}) 337 338 339 class MeanTensorTest(test.TestCase): 340 341 def setUp(self): 342 ops.reset_default_graph() 343 344 def testVars(self): 345 metrics.mean_tensor(array_ops.ones([4, 3])) 346 _assert_metric_variables(self, 347 ('mean/total_tensor:0', 'mean/count_tensor:0')) 348 349 def testMetricsCollection(self): 350 my_collection_name = '__metrics__' 351 mean, _ = metrics.mean_tensor( 352 array_ops.ones([4, 3]), metrics_collections=[my_collection_name]) 353 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 354 355 def testUpdatesCollection(self): 356 my_collection_name = '__updates__' 357 _, update_op = metrics.mean_tensor( 358 array_ops.ones([4, 3]), updates_collections=[my_collection_name]) 359 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 360 361 def testBasic(self): 362 with self.test_session() as sess: 363 values_queue = data_flow_ops.FIFOQueue( 364 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 365 _enqueue_vector(sess, values_queue, [0, 1]) 366 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 367 _enqueue_vector(sess, values_queue, [6.5, 0]) 368 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 369 values = values_queue.dequeue() 370 371 mean, update_op = metrics.mean_tensor(values) 372 373 sess.run(variables.local_variables_initializer()) 374 for _ in range(4): 375 sess.run(update_op) 376 self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean)) 377 378 def testMultiDimensional(self): 379 with 
self.test_session() as sess: 380 values_queue = data_flow_ops.FIFOQueue( 381 2, dtypes=dtypes_lib.float32, shapes=(2, 2, 2)) 382 _enqueue_vector( 383 sess, 384 values_queue, [[[1, 2], [1, 2]], [[1, 2], [1, 2]]], 385 shape=(2, 2, 2)) 386 _enqueue_vector( 387 sess, 388 values_queue, [[[1, 2], [1, 2]], [[3, 4], [9, 10]]], 389 shape=(2, 2, 2)) 390 values = values_queue.dequeue() 391 392 mean, update_op = metrics.mean_tensor(values) 393 394 sess.run(variables.local_variables_initializer()) 395 for _ in range(2): 396 sess.run(update_op) 397 self.assertAllClose([[[1, 2], [1, 2]], [[2, 3], [5, 6]]], sess.run(mean)) 398 399 def testUpdateOpsReturnsCurrentValue(self): 400 with self.test_session() as sess: 401 values_queue = data_flow_ops.FIFOQueue( 402 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 403 _enqueue_vector(sess, values_queue, [0, 1]) 404 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 405 _enqueue_vector(sess, values_queue, [6.5, 0]) 406 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 407 values = values_queue.dequeue() 408 409 mean, update_op = metrics.mean_tensor(values) 410 411 sess.run(variables.local_variables_initializer()) 412 413 self.assertAllClose([[0, 1]], sess.run(update_op), 5) 414 self.assertAllClose([[-2.1, 5.05]], sess.run(update_op), 5) 415 self.assertAllClose([[2.3 / 3., 10.1 / 3.]], sess.run(update_op), 5) 416 self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(update_op), 5) 417 418 self.assertAllClose([[-0.9 / 4., 3.525]], sess.run(mean), 5) 419 420 def testWeighted1d(self): 421 with self.test_session() as sess: 422 # Create the queue that populates the values. 423 values_queue = data_flow_ops.FIFOQueue( 424 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 425 _enqueue_vector(sess, values_queue, [0, 1]) 426 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 427 _enqueue_vector(sess, values_queue, [6.5, 0]) 428 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 429 values = values_queue.dequeue() 430 431 # Create the queue that populates the weights. 
432 weights_queue = data_flow_ops.FIFOQueue( 433 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 434 _enqueue_vector(sess, weights_queue, [[1]]) 435 _enqueue_vector(sess, weights_queue, [[0]]) 436 _enqueue_vector(sess, weights_queue, [[1]]) 437 _enqueue_vector(sess, weights_queue, [[0]]) 438 weights = weights_queue.dequeue() 439 440 mean, update_op = metrics.mean_tensor(values, weights) 441 442 sess.run(variables.local_variables_initializer()) 443 for _ in range(4): 444 sess.run(update_op) 445 self.assertAllClose([[3.25, 0.5]], sess.run(mean), 5) 446 447 def testWeighted2d_1(self): 448 with self.test_session() as sess: 449 # Create the queue that populates the values. 450 values_queue = data_flow_ops.FIFOQueue( 451 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 452 _enqueue_vector(sess, values_queue, [0, 1]) 453 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 454 _enqueue_vector(sess, values_queue, [6.5, 0]) 455 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 456 values = values_queue.dequeue() 457 458 # Create the queue that populates the weights. 459 weights_queue = data_flow_ops.FIFOQueue( 460 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 461 _enqueue_vector(sess, weights_queue, [1, 1]) 462 _enqueue_vector(sess, weights_queue, [1, 0]) 463 _enqueue_vector(sess, weights_queue, [0, 1]) 464 _enqueue_vector(sess, weights_queue, [0, 0]) 465 weights = weights_queue.dequeue() 466 467 mean, update_op = metrics.mean_tensor(values, weights) 468 469 sess.run(variables.local_variables_initializer()) 470 for _ in range(4): 471 sess.run(update_op) 472 self.assertAllClose([[-2.1, 0.5]], sess.run(mean), 5) 473 474 def testWeighted2d_2(self): 475 with self.test_session() as sess: 476 # Create the queue that populates the values. 
477 values_queue = data_flow_ops.FIFOQueue( 478 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 479 _enqueue_vector(sess, values_queue, [0, 1]) 480 _enqueue_vector(sess, values_queue, [-4.2, 9.1]) 481 _enqueue_vector(sess, values_queue, [6.5, 0]) 482 _enqueue_vector(sess, values_queue, [-3.2, 4.0]) 483 values = values_queue.dequeue() 484 485 # Create the queue that populates the weights. 486 weights_queue = data_flow_ops.FIFOQueue( 487 4, dtypes=dtypes_lib.float32, shapes=(1, 2)) 488 _enqueue_vector(sess, weights_queue, [0, 1]) 489 _enqueue_vector(sess, weights_queue, [0, 0]) 490 _enqueue_vector(sess, weights_queue, [0, 1]) 491 _enqueue_vector(sess, weights_queue, [0, 0]) 492 weights = weights_queue.dequeue() 493 494 mean, update_op = metrics.mean_tensor(values, weights) 495 496 sess.run(variables.local_variables_initializer()) 497 for _ in range(4): 498 sess.run(update_op) 499 self.assertAllClose([[0, 0.5]], sess.run(mean), 5) 500 501 502 class AccuracyTest(test.TestCase): 503 504 def setUp(self): 505 ops.reset_default_graph() 506 507 def testVars(self): 508 metrics.accuracy( 509 predictions=array_ops.ones((10, 1)), 510 labels=array_ops.ones((10, 1)), 511 name='my_accuracy') 512 _assert_metric_variables(self, 513 ('my_accuracy/count:0', 'my_accuracy/total:0')) 514 515 def testMetricsCollection(self): 516 my_collection_name = '__metrics__' 517 mean, _ = metrics.accuracy( 518 predictions=array_ops.ones((10, 1)), 519 labels=array_ops.ones((10, 1)), 520 metrics_collections=[my_collection_name]) 521 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 522 523 def testUpdatesCollection(self): 524 my_collection_name = '__updates__' 525 _, update_op = metrics.accuracy( 526 predictions=array_ops.ones((10, 1)), 527 labels=array_ops.ones((10, 1)), 528 updates_collections=[my_collection_name]) 529 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 530 531 def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self): 532 predictions = 
array_ops.ones((10, 3)) 533 labels = array_ops.ones((10, 4)) 534 with self.assertRaises(ValueError): 535 metrics.accuracy(labels, predictions) 536 537 def testPredictionsAndWeightsOfDifferentSizeRaisesValueError(self): 538 predictions = array_ops.ones((10, 3)) 539 labels = array_ops.ones((10, 3)) 540 weights = array_ops.ones((9, 3)) 541 with self.assertRaises(ValueError): 542 metrics.accuracy(labels, predictions, weights) 543 544 def testValueTensorIsIdempotent(self): 545 predictions = random_ops.random_uniform( 546 (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1) 547 labels = random_ops.random_uniform( 548 (10, 3), maxval=3, dtype=dtypes_lib.int64, seed=1) 549 accuracy, update_op = metrics.accuracy(labels, predictions) 550 551 with self.test_session() as sess: 552 sess.run(variables.local_variables_initializer()) 553 554 # Run several updates. 555 for _ in range(10): 556 sess.run(update_op) 557 558 # Then verify idempotency. 559 initial_accuracy = accuracy.eval() 560 for _ in range(10): 561 self.assertEqual(initial_accuracy, accuracy.eval()) 562 563 def testMultipleUpdates(self): 564 with self.test_session() as sess: 565 # Create the queue that populates the predictions. 566 preds_queue = data_flow_ops.FIFOQueue( 567 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 568 _enqueue_vector(sess, preds_queue, [0]) 569 _enqueue_vector(sess, preds_queue, [1]) 570 _enqueue_vector(sess, preds_queue, [2]) 571 _enqueue_vector(sess, preds_queue, [1]) 572 predictions = preds_queue.dequeue() 573 574 # Create the queue that populates the labels. 
575 labels_queue = data_flow_ops.FIFOQueue( 576 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 577 _enqueue_vector(sess, labels_queue, [0]) 578 _enqueue_vector(sess, labels_queue, [1]) 579 _enqueue_vector(sess, labels_queue, [1]) 580 _enqueue_vector(sess, labels_queue, [2]) 581 labels = labels_queue.dequeue() 582 583 accuracy, update_op = metrics.accuracy(labels, predictions) 584 585 sess.run(variables.local_variables_initializer()) 586 for _ in xrange(3): 587 sess.run(update_op) 588 self.assertEqual(0.5, sess.run(update_op)) 589 self.assertEqual(0.5, accuracy.eval()) 590 591 def testEffectivelyEquivalentSizes(self): 592 predictions = array_ops.ones((40, 1)) 593 labels = array_ops.ones((40,)) 594 with self.test_session() as sess: 595 accuracy, update_op = metrics.accuracy(labels, predictions) 596 597 sess.run(variables.local_variables_initializer()) 598 self.assertEqual(1.0, update_op.eval()) 599 self.assertEqual(1.0, accuracy.eval()) 600 601 def testEffectivelyEquivalentSizesWithScalarWeight(self): 602 predictions = array_ops.ones((40, 1)) 603 labels = array_ops.ones((40,)) 604 with self.test_session() as sess: 605 accuracy, update_op = metrics.accuracy(labels, predictions, weights=2.0) 606 607 sess.run(variables.local_variables_initializer()) 608 self.assertEqual(1.0, update_op.eval()) 609 self.assertEqual(1.0, accuracy.eval()) 610 611 def testEffectivelyEquivalentSizesWithStaticShapedWeight(self): 612 predictions = ops.convert_to_tensor([1, 1, 1]) # shape 3, 613 labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]), 614 1) # shape 3, 1 615 weights = array_ops.expand_dims(ops.convert_to_tensor([100, 1, 1]), 616 1) # shape 3, 1 617 618 with self.test_session() as sess: 619 accuracy, update_op = metrics.accuracy(labels, predictions, weights) 620 621 sess.run(variables.local_variables_initializer()) 622 # if streaming_accuracy does not flatten the weight, accuracy would be 623 # 0.33333334 due to an intended broadcast of weight. 
Due to flattening, 624 # it will be higher than .95 625 self.assertGreater(update_op.eval(), .95) 626 self.assertGreater(accuracy.eval(), .95) 627 628 def testEffectivelyEquivalentSizesWithDynamicallyShapedWeight(self): 629 predictions = ops.convert_to_tensor([1, 1, 1]) # shape 3, 630 labels = array_ops.expand_dims(ops.convert_to_tensor([1, 0, 0]), 631 1) # shape 3, 1 632 633 weights = [[100], [1], [1]] # shape 3, 1 634 weights_placeholder = array_ops.placeholder( 635 dtype=dtypes_lib.int32, name='weights') 636 feed_dict = {weights_placeholder: weights} 637 638 with self.test_session() as sess: 639 accuracy, update_op = metrics.accuracy(labels, predictions, 640 weights_placeholder) 641 642 sess.run(variables.local_variables_initializer()) 643 # if streaming_accuracy does not flatten the weight, accuracy would be 644 # 0.33333334 due to an intended broadcast of weight. Due to flattening, 645 # it will be higher than .95 646 self.assertGreater(update_op.eval(feed_dict=feed_dict), .95) 647 self.assertGreater(accuracy.eval(feed_dict=feed_dict), .95) 648 649 def testMultipleUpdatesWithWeightedValues(self): 650 with self.test_session() as sess: 651 # Create the queue that populates the predictions. 652 preds_queue = data_flow_ops.FIFOQueue( 653 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 654 _enqueue_vector(sess, preds_queue, [0]) 655 _enqueue_vector(sess, preds_queue, [1]) 656 _enqueue_vector(sess, preds_queue, [2]) 657 _enqueue_vector(sess, preds_queue, [1]) 658 predictions = preds_queue.dequeue() 659 660 # Create the queue that populates the labels. 661 labels_queue = data_flow_ops.FIFOQueue( 662 4, dtypes=dtypes_lib.float32, shapes=(1, 1)) 663 _enqueue_vector(sess, labels_queue, [0]) 664 _enqueue_vector(sess, labels_queue, [1]) 665 _enqueue_vector(sess, labels_queue, [1]) 666 _enqueue_vector(sess, labels_queue, [2]) 667 labels = labels_queue.dequeue() 668 669 # Create the queue that populates the weights. 
670 weights_queue = data_flow_ops.FIFOQueue( 671 4, dtypes=dtypes_lib.int64, shapes=(1, 1)) 672 _enqueue_vector(sess, weights_queue, [1]) 673 _enqueue_vector(sess, weights_queue, [1]) 674 _enqueue_vector(sess, weights_queue, [0]) 675 _enqueue_vector(sess, weights_queue, [0]) 676 weights = weights_queue.dequeue() 677 678 accuracy, update_op = metrics.accuracy(labels, predictions, weights) 679 680 sess.run(variables.local_variables_initializer()) 681 for _ in xrange(3): 682 sess.run(update_op) 683 self.assertEqual(1.0, sess.run(update_op)) 684 self.assertEqual(1.0, accuracy.eval()) 685 686 687 class PrecisionTest(test.TestCase): 688 689 def setUp(self): 690 np.random.seed(1) 691 ops.reset_default_graph() 692 693 def testVars(self): 694 metrics.precision( 695 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 696 _assert_metric_variables(self, ('precision/false_positives/count:0', 697 'precision/true_positives/count:0')) 698 699 def testMetricsCollection(self): 700 my_collection_name = '__metrics__' 701 mean, _ = metrics.precision( 702 predictions=array_ops.ones((10, 1)), 703 labels=array_ops.ones((10, 1)), 704 metrics_collections=[my_collection_name]) 705 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 706 707 def testUpdatesCollection(self): 708 my_collection_name = '__updates__' 709 _, update_op = metrics.precision( 710 predictions=array_ops.ones((10, 1)), 711 labels=array_ops.ones((10, 1)), 712 updates_collections=[my_collection_name]) 713 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 714 715 def testValueTensorIsIdempotent(self): 716 predictions = random_ops.random_uniform( 717 (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) 718 labels = random_ops.random_uniform( 719 (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) 720 precision, update_op = metrics.precision(labels, predictions) 721 722 with self.test_session() as sess: 723 sess.run(variables.local_variables_initializer()) 724 725 # Run 
several updates. 726 for _ in range(10): 727 sess.run(update_op) 728 729 # Then verify idempotency. 730 initial_precision = precision.eval() 731 for _ in range(10): 732 self.assertEqual(initial_precision, precision.eval()) 733 734 def testAllCorrect(self): 735 inputs = np.random.randint(0, 2, size=(100, 1)) 736 737 predictions = constant_op.constant(inputs) 738 labels = constant_op.constant(inputs) 739 precision, update_op = metrics.precision(labels, predictions) 740 741 with self.test_session() as sess: 742 sess.run(variables.local_variables_initializer()) 743 self.assertAlmostEqual(1, sess.run(update_op)) 744 self.assertAlmostEqual(1, precision.eval()) 745 746 def testSomeCorrect_multipleInputDtypes(self): 747 for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 748 predictions = math_ops.cast( 749 constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype) 750 labels = math_ops.cast( 751 constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype) 752 precision, update_op = metrics.precision(labels, predictions) 753 754 with self.test_session() as sess: 755 sess.run(variables.local_variables_initializer()) 756 self.assertAlmostEqual(0.5, update_op.eval()) 757 self.assertAlmostEqual(0.5, precision.eval()) 758 759 def testWeighted1d(self): 760 predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]]) 761 labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) 762 precision, update_op = metrics.precision( 763 labels, predictions, weights=constant_op.constant([[2], [5]])) 764 765 with self.test_session(): 766 variables.local_variables_initializer().run() 767 weighted_tp = 2.0 + 5.0 768 weighted_positives = (2.0 + 2.0) + (5.0 + 5.0) 769 expected_precision = weighted_tp / weighted_positives 770 self.assertAlmostEqual(expected_precision, update_op.eval()) 771 self.assertAlmostEqual(expected_precision, precision.eval()) 772 773 def testWeightedScalar_placeholders(self): 774 predictions = array_ops.placeholder(dtype=dtypes_lib.float32) 
775 labels = array_ops.placeholder(dtype=dtypes_lib.float32) 776 feed_dict = { 777 predictions: ((1, 0, 1, 0), (1, 0, 1, 0)), 778 labels: ((0, 1, 1, 0), (1, 0, 0, 1)) 779 } 780 precision, update_op = metrics.precision(labels, predictions, weights=2) 781 782 with self.test_session(): 783 variables.local_variables_initializer().run() 784 weighted_tp = 2.0 + 2.0 785 weighted_positives = (2.0 + 2.0) + (2.0 + 2.0) 786 expected_precision = weighted_tp / weighted_positives 787 self.assertAlmostEqual( 788 expected_precision, update_op.eval(feed_dict=feed_dict)) 789 self.assertAlmostEqual( 790 expected_precision, precision.eval(feed_dict=feed_dict)) 791 792 def testWeighted1d_placeholders(self): 793 predictions = array_ops.placeholder(dtype=dtypes_lib.float32) 794 labels = array_ops.placeholder(dtype=dtypes_lib.float32) 795 feed_dict = { 796 predictions: ((1, 0, 1, 0), (1, 0, 1, 0)), 797 labels: ((0, 1, 1, 0), (1, 0, 0, 1)) 798 } 799 precision, update_op = metrics.precision( 800 labels, predictions, weights=constant_op.constant([[2], [5]])) 801 802 with self.test_session(): 803 variables.local_variables_initializer().run() 804 weighted_tp = 2.0 + 5.0 805 weighted_positives = (2.0 + 2.0) + (5.0 + 5.0) 806 expected_precision = weighted_tp / weighted_positives 807 self.assertAlmostEqual( 808 expected_precision, update_op.eval(feed_dict=feed_dict)) 809 self.assertAlmostEqual( 810 expected_precision, precision.eval(feed_dict=feed_dict)) 811 812 def testWeighted2d(self): 813 predictions = constant_op.constant([[1, 0, 1, 0], [1, 0, 1, 0]]) 814 labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]]) 815 precision, update_op = metrics.precision( 816 labels, 817 predictions, 818 weights=constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])) 819 820 with self.test_session(): 821 variables.local_variables_initializer().run() 822 weighted_tp = 3.0 + 4.0 823 weighted_positives = (1.0 + 3.0) + (4.0 + 2.0) 824 expected_precision = weighted_tp / weighted_positives 825 
class RecallTest(test.TestCase):
  """Tests for `metrics.recall`."""

  def setUp(self):
    # Seed NumPy and start every test from an empty default graph.
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    """Checks the names of the local variables created by the metric."""
    metrics.recall(
        predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1)))
    _assert_metric_variables(
        self,
        ('recall/false_negatives/count:0', 'recall/true_positives/count:0'))

  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    mean, _ = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.recall(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def testValueTensorIsIdempotent(self):
    """Evaluating the value tensor repeatedly must not change its result."""
    # For integer dtypes `maxval` is exclusive, so `maxval=1` would make every
    # prediction and label 0 and recall the degenerate 0/0 case. Use
    # `maxval=2` (and distinct seeds) so the inputs are real binary data.
    predictions = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=2)
    recall, update_op = metrics.recall(labels, predictions)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        sess.run(update_op)

      # Then verify idempotency.
      initial_recall = recall.eval()
      for _ in range(10):
        self.assertEqual(initial_recall, recall.eval())

  def testAllCorrect(self):
    """Predictions identical to the labels give a recall of 1."""
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)
      self.assertEqual(1, recall.eval())

  def testSomeCorrect_multipleInputDtypes(self):
    """One of two positives is found, for bool, int32 and float32 inputs."""
    for dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = math_ops.cast(
          constant_op.constant([1, 0, 1, 0], shape=(1, 4)), dtype=dtype)
      labels = math_ops.cast(
          constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=dtype)
      recall, update_op = metrics.recall(labels, predictions)

      with self.test_session() as sess:
        sess.run(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, update_op.eval())
        self.assertAlmostEqual(0.5, recall.eval())

  def testWeighted1d(self):
    """Per-row weights of shape [2, 1] scale each example's contribution."""
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[2], [5]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      weighted_tp = 2.0 + 5.0
      weighted_t = (2.0 + 2.0) + (5.0 + 5.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, update_op.eval())
      self.assertAlmostEqual(expected_recall, recall.eval())

  def testWeighted2d(self):
    """Element-wise weights of shape [2, 4] scale each entry's contribution."""
    predictions = constant_op.constant([[1, 0, 1, 0], [0, 1, 0, 1]])
    labels = constant_op.constant([[0, 1, 1, 0], [1, 0, 0, 1]])
    weights = constant_op.constant([[1, 2, 3, 4], [4, 3, 2, 1]])
    recall, update_op = metrics.recall(labels, predictions, weights=weights)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      weighted_tp = 3.0 + 1.0
      weighted_t = (2.0 + 3.0) + (4.0 + 1.0)
      expected_recall = weighted_tp / weighted_t
      self.assertAlmostEqual(expected_recall, update_op.eval())
      self.assertAlmostEqual(expected_recall, recall.eval())

  def testAllIncorrect(self):
    """Predictions that are the complement of the labels give a recall of 0."""
    np_inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(np_inputs)
    labels = constant_op.constant(1 - np_inputs)
    recall, update_op = metrics.recall(labels, predictions)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)
      self.assertEqual(0, recall.eval())

  def testZeroTruePositivesAndFalseNegativesGivesZeroRecall(self):
    """With no positive labels at all the metric reports 0."""
    predictions = array_ops.zeros((1, 4))
    labels = array_ops.zeros((1, 4))
    recall, update_op = metrics.recall(labels, predictions)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      sess.run(update_op)
      self.assertEqual(0, recall.eval())
class AUCTest(test.TestCase):
  """Tests for `metrics.auc`."""

  def setUp(self):
    # Seed NumPy and start every test from an empty default graph.
    np.random.seed(1)
    ops.reset_default_graph()

  def testVars(self):
    """Checks the names of the local variables created by the metric."""
    metrics.auc(predictions=array_ops.ones((10, 1)),
                labels=array_ops.ones((10, 1)))
    _assert_metric_variables(self,
                             ('auc/true_positives:0', 'auc/false_negatives:0',
                              'auc/false_positives:0', 'auc/true_negatives:0'))

  def testMetricsCollection(self):
    """The value tensor is added to the requested metrics collection."""
    my_collection_name = '__metrics__'
    mean, _ = metrics.auc(predictions=array_ops.ones((10, 1)),
                          labels=array_ops.ones((10, 1)),
                          metrics_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [mean])

  def testUpdatesCollection(self):
    """The update op is added to the requested updates collection."""
    my_collection_name = '__updates__'
    _, update_op = metrics.auc(predictions=array_ops.ones((10, 1)),
                               labels=array_ops.ones((10, 1)),
                               updates_collections=[my_collection_name])
    self.assertListEqual(ops.get_collection(my_collection_name), [update_op])

  def testValueTensorIsIdempotent(self):
    """Evaluating the value tensor repeatedly must not change its result."""
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    # For integer dtypes `maxval` is exclusive, so `maxval=1` would make every
    # label 0. Use `maxval=2` so both classes are present.
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    auc, update_op = metrics.auc(labels, predictions)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Run several updates.
      for _ in range(10):
        sess.run(update_op)

      # Then verify idempotency.
      initial_auc = auc.eval()
      for _ in range(10):
        self.assertAlmostEqual(initial_auc, auc.eval(), 5)

  def testAllCorrect(self):
    self.allCorrectAsExpected('ROC')

  def allCorrectAsExpected(self, curve):
    """Asserts an AUC of 1 when predictions equal the labels.

    Args:
      curve: value forwarded as the `curve` argument of `metrics.auc`.
    """
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.test_session() as sess:
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(inputs)
      auc, update_op = metrics.auc(labels, predictions, curve=curve)

      sess.run(variables.local_variables_initializer())
      self.assertEqual(1, sess.run(update_op))

      self.assertEqual(1, auc.eval())

  def testSomeCorrect_multipleLabelDtypes(self):
    """Half-correct predictions give 0.5 for bool, int32 and float32 labels."""
    with self.test_session() as sess:
      for label_dtype in (
          dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
        predictions = constant_op.constant(
            [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
        labels = math_ops.cast(
            constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype)
        auc, update_op = metrics.auc(labels, predictions)

        sess.run(variables.local_variables_initializer())
        self.assertAlmostEqual(0.5, sess.run(update_op))

        self.assertAlmostEqual(0.5, auc.eval())

  def testWeighted1d(self):
    """A single weight of shape [1, 1] leaves the AUC at 0.5."""
    with self.test_session() as sess:
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([2], shape=(1, 1))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.5, sess.run(update_op), 5)

      self.assertAlmostEqual(0.5, auc.eval(), 5)

  def testWeighted2d(self):
    """Element-wise weights of shape [1, 4] change the AUC to 0.7."""
    with self.test_session() as sess:
      predictions = constant_op.constant(
          [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 1, 1, 0], shape=(1, 4))
      weights = constant_op.constant([1, 2, 3, 4], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, weights=weights)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.7, sess.run(update_op), 5)

      self.assertAlmostEqual(0.7, auc.eval(), 5)

  def testAUCPRSpecialCase(self):
    with self.test_session() as sess:
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8], shape=(1, 4), dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 1], shape=(1, 4))
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.54166, sess.run(update_op), delta=1e-3)

      self.assertAlmostEqual(0.54166, auc.eval(), delta=1e-3)

  def testAnotherAUCPRSpecialCase(self):
    with self.test_session() as sess:
      predictions = constant_op.constant(
          [0.1, 0.4, 0.35, 0.8, 0.1, 0.135, 0.81],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 1, 0, 1, 0, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.44365042, sess.run(update_op), delta=1e-3)

      self.assertAlmostEqual(0.44365042, auc.eval(), delta=1e-3)

  def testThirdAUCPRSpecialCase(self):
    with self.test_session() as sess:
      predictions = constant_op.constant(
          [0.0, 0.1, 0.2, 0.33, 0.3, 0.4, 0.5],
          shape=(1, 7),
          dtype=dtypes_lib.float32)
      labels = constant_op.constant([0, 0, 0, 0, 1, 1, 1], shape=(1, 7))
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.73611039, sess.run(update_op), delta=1e-3)

      self.assertAlmostEqual(0.73611039, auc.eval(), delta=1e-3)

  def testFourthAUCPRSpecialCase(self):
    # Create the labels and data.
    labels = np.array([
        0, 0, 0, 0, 0, 0, 0, 1, 0, 1])
    predictions = np.array([
        0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35, 0.35])

    with self.test_session() as sess:
      # NOTE(review): the update op is discarded and never run, so the value
      # below is read from freshly initialized variables -- confirm whether
      # that is the intent of this test.
      auc, _ = metrics.auc(
          labels, predictions, curve='PR', num_thresholds=11)

      sess.run(variables.local_variables_initializer())
      # Since this is only approximate, we can't expect a 6 digits match.
      # Although with higher number of samples/thresholds we should see the
      # accuracy improving
      self.assertAlmostEqual(0.0, auc.eval(), delta=0.001)

  def testAllIncorrect(self):
    """Predictions that are the complement of the labels give an AUC of 0."""
    inputs = np.random.randint(0, 2, size=(100, 1))

    with self.test_session() as sess:
      predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
      labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32)
      auc, update_op = metrics.auc(labels, predictions)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0, sess.run(update_op))

      self.assertAlmostEqual(0, auc.eval())

  def testZeroTruePositivesAndFalseNegativesGivesOneAUC(self):
    with self.test_session() as sess:
      predictions = array_ops.zeros([4], dtype=dtypes_lib.float32)
      labels = array_ops.zeros([4])
      auc, update_op = metrics.auc(labels, predictions)

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(1, sess.run(update_op), 6)

      self.assertAlmostEqual(1, auc.eval(), 6)

  def testRecallOneAndPrecisionOne(self):
    with self.test_session() as sess:
      predictions = array_ops.ones([4], dtype=dtypes_lib.float32)
      labels = array_ops.ones([4])
      auc, update_op = metrics.auc(labels, predictions, curve='PR')

      sess.run(variables.local_variables_initializer())
      self.assertAlmostEqual(0.5, sess.run(update_op), 6)

      self.assertAlmostEqual(0.5, auc.eval(), 6)

  def np_auc(self, predictions, labels, weights):
    """Computes the AUC explicitly using Numpy.

    Args:
      predictions: an ndarray with shape [N].
      labels: an ndarray with shape [N].
      weights: an ndarray with shape [N], or `None` to weight every example
        by 1.

    Returns:
      the area under the ROC curve.
    """
    if weights is None:
      weights = np.ones(np.size(predictions))
    is_positive = labels > 0
    num_positives = np.sum(weights[is_positive])
    num_negatives = np.sum(weights[~is_positive])

    # Sort descending:
    inds = np.argsort(-predictions)

    sorted_labels = labels[inds]
    sorted_weights = weights[inds]
    is_positive = sorted_labels > 0

    # Running weighted fraction of positives ranked at or above each example.
    tp = np.cumsum(sorted_weights * is_positive) / num_positives
    # Weighted average of that fraction over the negative examples.
    return np.sum((sorted_weights * tp)[~is_positive]) / num_negatives

  def testWithMultipleUpdates(self):
    """Streams the data in batches and compares against `np_auc`."""
    num_samples = 1000
    batch_size = 10
    num_batches = num_samples // batch_size

    # Create the labels and data.
    labels = np.random.randint(0, 2, size=num_samples)
    noise = np.random.normal(0.0, scale=0.2, size=num_samples)
    predictions = 0.4 + 0.2 * labels + noise
    predictions[predictions > 1] = 1
    predictions[predictions < 0] = 0

    def _enqueue_as_batches(x, enqueue_ops):
      # Splits `x` into `num_batches` rows, appends one enqueue op per row to
      # the matching entry of `enqueue_ops`, and returns the dequeue tensor.
      x_batches = x.astype(np.float32).reshape((num_batches, batch_size))
      x_queue = data_flow_ops.FIFOQueue(
          num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
      for i in range(num_batches):
        enqueue_ops[i].append(x_queue.enqueue(x_batches[i, :]))
      return x_queue.dequeue()

    for weights in (None, np.ones(num_samples), np.random.exponential(
        scale=1.0, size=num_samples)):
      expected_auc = self.np_auc(predictions, labels, weights)

      with self.test_session() as sess:
        enqueue_ops = [[] for _ in range(num_batches)]
        tf_predictions = _enqueue_as_batches(predictions, enqueue_ops)
        tf_labels = _enqueue_as_batches(labels, enqueue_ops)
        tf_weights = (_enqueue_as_batches(weights, enqueue_ops) if
                      weights is not None else None)

        # Each entry enqueues the i-th batch of every stream together.
        for batch_enqueue_ops in enqueue_ops:
          sess.run(batch_enqueue_ops)

        auc, update_op = metrics.auc(tf_labels,
                                     tf_predictions,
                                     curve='ROC',
                                     num_thresholds=500,
                                     weights=tf_weights)

        sess.run(variables.local_variables_initializer())
        for _ in range(num_batches):
          sess.run(update_op)

        # Since this is only approximate, we can't expect a 6 digits match.
        # Although with higher number of samples/thresholds we should see the
        # accuracy improving
        self.assertAlmostEqual(expected_auc, auc.eval(), 2)
class SpecificityAtSensitivityTest(test.TestCase):
  """Tests for `metrics.specificity_at_sensitivity`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _assert_metric(self, expected, value_tensor, update_op, exact=False):
    """Initializes local variables, then checks the update op and the value.

    Args:
      expected: number both results are compared against.
      value_tensor: metric value tensor, evaluated after the update op.
      update_op: metric update op, run exactly once.
      exact: if True compare with `assertEqual`, else `assertAlmostEqual`.
    """
    check = self.assertEqual if exact else self.assertAlmostEqual
    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      check(expected, sess.run(update_op))
      check(expected, value_tensor.eval())

  def testVars(self):
    metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7)
    expected_names = ('specificity_at_sensitivity/true_positives:0',
                      'specificity_at_sensitivity/false_negatives:0',
                      'specificity_at_sensitivity/false_positives:0',
                      'specificity_at_sensitivity/true_negatives:0')
    _assert_metric_variables(self, expected_names)

  def testMetricsCollection(self):
    collection = '__metrics__'
    value, _ = metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7,
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value])

  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update = metrics.specificity_at_sensitivity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        sensitivity=0.7,
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update])

  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    specificity, update_op = metrics.specificity_at_sensitivity(
        labels, predictions, sensitivity=0.7)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Accumulate a few batches first ...
      for _ in range(10):
        sess.run(update_op)

      # ... then reading the value repeatedly must keep returning the same.
      reference = specificity.eval()
      for _ in range(10):
        self.assertAlmostEqual(reference, specificity.eval(), 5)

  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
    labels = constant_op.constant(inputs)
    specificity, update_op = metrics.specificity_at_sensitivity(
        labels, predictions, sensitivity=0.7)

    self._assert_metric(1, specificity, update_op, exact=True)

  def testSomeCorrectHighSensitivity(self):
    scores = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.45, 0.5, 0.8, 0.9]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
    labels = constant_op.constant(targets)
    specificity, update_op = metrics.specificity_at_sensitivity(
        labels, predictions, sensitivity=0.8)

    self._assert_metric(1.0, specificity, update_op)

  def testSomeCorrectLowSensitivity(self):
    scores = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
    labels = constant_op.constant(targets)
    specificity, update_op = metrics.specificity_at_sensitivity(
        labels, predictions, sensitivity=0.4)

    self._assert_metric(0.6, specificity, update_op)

  def testWeighted1d_multipleLabelDtypes(self):
    scores = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    example_weights = [3]

    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
      labels = math_ops.cast(targets, dtype=label_dtype)
      weights = constant_op.constant(example_weights)
      specificity, update_op = metrics.specificity_at_sensitivity(
          labels, predictions, weights=weights, sensitivity=0.4)

      self._assert_metric(0.6, specificity, update_op)

  def testWeighted2d(self):
    scores = [0.1, 0.2, 0.4, 0.3, 0.0, 0.1, 0.2, 0.2, 0.26, 0.26]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    example_weights = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
    labels = constant_op.constant(targets)
    weights = constant_op.constant(example_weights)
    specificity, update_op = metrics.specificity_at_sensitivity(
        labels, predictions, weights=weights, sensitivity=0.4)

    self._assert_metric(8.0 / 15.0, specificity, update_op)
class SensitivityAtSpecificityTest(test.TestCase):
  """Tests for `metrics.sensitivity_at_specificity`."""

  def setUp(self):
    np.random.seed(1)
    ops.reset_default_graph()

  def _assert_metric(self, expected, value_tensor, update_op, exact=False):
    """Initializes local variables, then checks the update op and the value.

    Args:
      expected: number both results are compared against.
      value_tensor: metric value tensor, evaluated after the update op.
      update_op: metric update op, run exactly once.
      exact: if True compare with `assertEqual`, else `assertAlmostEqual`.
    """
    check = self.assertEqual if exact else self.assertAlmostEqual
    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())
      check(expected, sess.run(update_op))
      check(expected, value_tensor.eval())

  def testVars(self):
    metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7)
    expected_names = ('sensitivity_at_specificity/true_positives:0',
                      'sensitivity_at_specificity/false_negatives:0',
                      'sensitivity_at_specificity/false_positives:0',
                      'sensitivity_at_specificity/true_negatives:0')
    _assert_metric_variables(self, expected_names)

  def testMetricsCollection(self):
    collection = '__metrics__'
    value, _ = metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7,
        metrics_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [value])

  def testUpdatesCollection(self):
    collection = '__updates__'
    _, update = metrics.sensitivity_at_specificity(
        predictions=array_ops.ones((10, 1)),
        labels=array_ops.ones((10, 1)),
        specificity=0.7,
        updates_collections=[collection])
    self.assertListEqual(ops.get_collection(collection), [update])

  def testValueTensorIsIdempotent(self):
    predictions = random_ops.random_uniform(
        (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1)
    labels = random_ops.random_uniform(
        (10, 3), maxval=2, dtype=dtypes_lib.int64, seed=1)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.7)

    with self.test_session() as sess:
      sess.run(variables.local_variables_initializer())

      # Accumulate a few batches first ...
      for _ in range(10):
        sess.run(update_op)

      # ... then reading the value repeatedly must keep returning the same.
      reference = sensitivity.eval()
      for _ in range(10):
        self.assertAlmostEqual(reference, sensitivity.eval(), 5)

  def testAllCorrect(self):
    inputs = np.random.randint(0, 2, size=(100, 1))

    predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32)
    labels = constant_op.constant(inputs)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.7)

    self._assert_metric(1, sensitivity, update_op, exact=True)

  def testSomeCorrectHighSpecificity(self):
    scores = [0.0, 0.1, 0.2, 0.3, 0.4, 0.1, 0.45, 0.5, 0.8, 0.9]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
    labels = constant_op.constant(targets)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.8)

    self._assert_metric(0.8, sensitivity, update_op)

  def testSomeCorrectLowSpecificity(self):
    scores = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]

    predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
    labels = constant_op.constant(targets)
    sensitivity, update_op = metrics.sensitivity_at_specificity(
        labels, predictions, specificity=0.4)

    self._assert_metric(0.6, sensitivity, update_op)

  def testWeighted_multipleLabelDtypes(self):
    scores = [0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26]
    targets = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1]
    example_weights = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10]

    for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32):
      predictions = constant_op.constant(scores, dtype=dtypes_lib.float32)
      labels = math_ops.cast(targets, dtype=label_dtype)
      weights = constant_op.constant(example_weights)
      sensitivity, update_op = metrics.sensitivity_at_specificity(
          labels, predictions, weights=weights, specificity=0.4)

      self._assert_metric(0.675, sensitivity, update_op)
1505 for label_dtype in (dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1506 predictions_values = [ 1507 0.0, 0.1, 0.2, 0.3, 0.4, 0.01, 0.02, 0.25, 0.26, 0.26] 1508 labels_values = [0, 0, 0, 0, 0, 1, 1, 1, 1, 1] 1509 weights_values = [1, 2, 3, 4, 5, 6, 7, 8, 9, 10] 1510 1511 predictions = constant_op.constant( 1512 predictions_values, dtype=dtypes_lib.float32) 1513 labels = math_ops.cast(labels_values, dtype=label_dtype) 1514 weights = constant_op.constant(weights_values) 1515 specificity, update_op = metrics.sensitivity_at_specificity( 1516 labels, predictions, weights=weights, specificity=0.4) 1517 1518 with self.test_session() as sess: 1519 sess.run(variables.local_variables_initializer()) 1520 self.assertAlmostEqual(0.675, sess.run(update_op)) 1521 self.assertAlmostEqual(0.675, specificity.eval()) 1522 1523 1524 # TODO(nsilberman): Break this up into two sets of tests. 1525 class PrecisionRecallThresholdsTest(test.TestCase): 1526 1527 def setUp(self): 1528 np.random.seed(1) 1529 ops.reset_default_graph() 1530 1531 def testVars(self): 1532 metrics.precision_at_thresholds( 1533 predictions=array_ops.ones((10, 1)), 1534 labels=array_ops.ones((10, 1)), 1535 thresholds=[0, 0.5, 1.0]) 1536 _assert_metric_variables(self, ( 1537 'precision_at_thresholds/true_positives:0', 1538 'precision_at_thresholds/false_positives:0', 1539 )) 1540 1541 def testMetricsCollection(self): 1542 my_collection_name = '__metrics__' 1543 prec, _ = metrics.precision_at_thresholds( 1544 predictions=array_ops.ones((10, 1)), 1545 labels=array_ops.ones((10, 1)), 1546 thresholds=[0, 0.5, 1.0], 1547 metrics_collections=[my_collection_name]) 1548 rec, _ = metrics.recall_at_thresholds( 1549 predictions=array_ops.ones((10, 1)), 1550 labels=array_ops.ones((10, 1)), 1551 thresholds=[0, 0.5, 1.0], 1552 metrics_collections=[my_collection_name]) 1553 self.assertListEqual(ops.get_collection(my_collection_name), [prec, rec]) 1554 1555 def testUpdatesCollection(self): 1556 my_collection_name = 
'__updates__' 1557 _, precision_op = metrics.precision_at_thresholds( 1558 predictions=array_ops.ones((10, 1)), 1559 labels=array_ops.ones((10, 1)), 1560 thresholds=[0, 0.5, 1.0], 1561 updates_collections=[my_collection_name]) 1562 _, recall_op = metrics.recall_at_thresholds( 1563 predictions=array_ops.ones((10, 1)), 1564 labels=array_ops.ones((10, 1)), 1565 thresholds=[0, 0.5, 1.0], 1566 updates_collections=[my_collection_name]) 1567 self.assertListEqual( 1568 ops.get_collection(my_collection_name), [precision_op, recall_op]) 1569 1570 def testValueTensorIsIdempotent(self): 1571 predictions = random_ops.random_uniform( 1572 (10, 3), maxval=1, dtype=dtypes_lib.float32, seed=1) 1573 labels = random_ops.random_uniform( 1574 (10, 3), maxval=1, dtype=dtypes_lib.int64, seed=1) 1575 thresholds = [0, 0.5, 1.0] 1576 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1577 thresholds) 1578 rec, rec_op = metrics.recall_at_thresholds(labels, predictions, thresholds) 1579 1580 with self.test_session() as sess: 1581 sess.run(variables.local_variables_initializer()) 1582 1583 # Run several updates, then verify idempotency. 1584 sess.run([prec_op, rec_op]) 1585 initial_prec = prec.eval() 1586 initial_rec = rec.eval() 1587 for _ in range(10): 1588 sess.run([prec_op, rec_op]) 1589 self.assertAllClose(initial_prec, prec.eval()) 1590 self.assertAllClose(initial_rec, rec.eval()) 1591 1592 # TODO(nsilberman): fix tests (passing but incorrect). 
1593 def testAllCorrect(self): 1594 inputs = np.random.randint(0, 2, size=(100, 1)) 1595 1596 with self.test_session() as sess: 1597 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1598 labels = constant_op.constant(inputs) 1599 thresholds = [0.5] 1600 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1601 thresholds) 1602 rec, rec_op = metrics.recall_at_thresholds(labels, predictions, 1603 thresholds) 1604 1605 sess.run(variables.local_variables_initializer()) 1606 sess.run([prec_op, rec_op]) 1607 1608 self.assertEqual(1, prec.eval()) 1609 self.assertEqual(1, rec.eval()) 1610 1611 def testSomeCorrect_multipleLabelDtypes(self): 1612 with self.test_session() as sess: 1613 for label_dtype in ( 1614 dtypes_lib.bool, dtypes_lib.int32, dtypes_lib.float32): 1615 predictions = constant_op.constant( 1616 [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) 1617 labels = math_ops.cast( 1618 constant_op.constant([0, 1, 1, 0], shape=(1, 4)), dtype=label_dtype) 1619 thresholds = [0.5] 1620 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1621 thresholds) 1622 rec, rec_op = metrics.recall_at_thresholds(labels, predictions, 1623 thresholds) 1624 1625 sess.run(variables.local_variables_initializer()) 1626 sess.run([prec_op, rec_op]) 1627 1628 self.assertAlmostEqual(0.5, prec.eval()) 1629 self.assertAlmostEqual(0.5, rec.eval()) 1630 1631 def testAllIncorrect(self): 1632 inputs = np.random.randint(0, 2, size=(100, 1)) 1633 1634 with self.test_session() as sess: 1635 predictions = constant_op.constant(inputs, dtype=dtypes_lib.float32) 1636 labels = constant_op.constant(1 - inputs, dtype=dtypes_lib.float32) 1637 thresholds = [0.5] 1638 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1639 thresholds) 1640 rec, rec_op = metrics.recall_at_thresholds(labels, predictions, 1641 thresholds) 1642 1643 sess.run(variables.local_variables_initializer()) 1644 sess.run([prec_op, rec_op]) 1645 1646 
self.assertAlmostEqual(0, prec.eval()) 1647 self.assertAlmostEqual(0, rec.eval()) 1648 1649 def testWeights1d(self): 1650 with self.test_session() as sess: 1651 predictions = constant_op.constant( 1652 [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) 1653 labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) 1654 weights = constant_op.constant( 1655 [[0], [1]], shape=(2, 1), dtype=dtypes_lib.float32) 1656 thresholds = [0.5, 1.1] 1657 prec, prec_op = metrics.precision_at_thresholds( 1658 labels, predictions, thresholds, weights=weights) 1659 rec, rec_op = metrics.recall_at_thresholds( 1660 labels, predictions, thresholds, weights=weights) 1661 1662 [prec_low, prec_high] = array_ops.split( 1663 value=prec, num_or_size_splits=2, axis=0) 1664 prec_low = array_ops.reshape(prec_low, shape=()) 1665 prec_high = array_ops.reshape(prec_high, shape=()) 1666 [rec_low, rec_high] = array_ops.split( 1667 value=rec, num_or_size_splits=2, axis=0) 1668 rec_low = array_ops.reshape(rec_low, shape=()) 1669 rec_high = array_ops.reshape(rec_high, shape=()) 1670 1671 sess.run(variables.local_variables_initializer()) 1672 sess.run([prec_op, rec_op]) 1673 1674 self.assertAlmostEqual(1.0, prec_low.eval(), places=5) 1675 self.assertAlmostEqual(0.0, prec_high.eval(), places=5) 1676 self.assertAlmostEqual(1.0, rec_low.eval(), places=5) 1677 self.assertAlmostEqual(0.0, rec_high.eval(), places=5) 1678 1679 def testWeights2d(self): 1680 with self.test_session() as sess: 1681 predictions = constant_op.constant( 1682 [[1, 0], [1, 0]], shape=(2, 2), dtype=dtypes_lib.float32) 1683 labels = constant_op.constant([[0, 1], [1, 0]], shape=(2, 2)) 1684 weights = constant_op.constant( 1685 [[0, 0], [1, 1]], shape=(2, 2), dtype=dtypes_lib.float32) 1686 thresholds = [0.5, 1.1] 1687 prec, prec_op = metrics.precision_at_thresholds( 1688 labels, predictions, thresholds, weights=weights) 1689 rec, rec_op = metrics.recall_at_thresholds( 1690 labels, predictions, thresholds, weights=weights) 1691 
1692 [prec_low, prec_high] = array_ops.split( 1693 value=prec, num_or_size_splits=2, axis=0) 1694 prec_low = array_ops.reshape(prec_low, shape=()) 1695 prec_high = array_ops.reshape(prec_high, shape=()) 1696 [rec_low, rec_high] = array_ops.split( 1697 value=rec, num_or_size_splits=2, axis=0) 1698 rec_low = array_ops.reshape(rec_low, shape=()) 1699 rec_high = array_ops.reshape(rec_high, shape=()) 1700 1701 sess.run(variables.local_variables_initializer()) 1702 sess.run([prec_op, rec_op]) 1703 1704 self.assertAlmostEqual(1.0, prec_low.eval(), places=5) 1705 self.assertAlmostEqual(0.0, prec_high.eval(), places=5) 1706 self.assertAlmostEqual(1.0, rec_low.eval(), places=5) 1707 self.assertAlmostEqual(0.0, rec_high.eval(), places=5) 1708 1709 def testExtremeThresholds(self): 1710 with self.test_session() as sess: 1711 predictions = constant_op.constant( 1712 [1, 0, 1, 0], shape=(1, 4), dtype=dtypes_lib.float32) 1713 labels = constant_op.constant([0, 1, 1, 1], shape=(1, 4)) 1714 thresholds = [-1.0, 2.0] # lower/higher than any values 1715 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1716 thresholds) 1717 rec, rec_op = metrics.recall_at_thresholds(labels, predictions, 1718 thresholds) 1719 1720 [prec_low, prec_high] = array_ops.split( 1721 value=prec, num_or_size_splits=2, axis=0) 1722 [rec_low, rec_high] = array_ops.split( 1723 value=rec, num_or_size_splits=2, axis=0) 1724 1725 sess.run(variables.local_variables_initializer()) 1726 sess.run([prec_op, rec_op]) 1727 1728 self.assertAlmostEqual(0.75, prec_low.eval()) 1729 self.assertAlmostEqual(0.0, prec_high.eval()) 1730 self.assertAlmostEqual(1.0, rec_low.eval()) 1731 self.assertAlmostEqual(0.0, rec_high.eval()) 1732 1733 def testZeroLabelsPredictions(self): 1734 with self.test_session() as sess: 1735 predictions = array_ops.zeros([4], dtype=dtypes_lib.float32) 1736 labels = array_ops.zeros([4]) 1737 thresholds = [0.5] 1738 prec, prec_op = metrics.precision_at_thresholds(labels, predictions, 1739 
def testWithMultipleUpdates(self):
  """Feeds random data through the metrics in batches via FIFO queues.

  The expected precision and recall at the single threshold 0.3 are computed
  with NumPy over the whole data set, then compared (to 2 places) with the
  values the streaming metrics report after every batch has been consumed.
  """
  num_samples = 1000
  batch_size = 10
  num_batches = num_samples // batch_size

  # Create the labels and data.
  labels = np.random.randint(0, 2, size=(num_samples, 1))
  noise = np.random.normal(0.0, scale=0.2, size=(num_samples, 1))
  predictions = np.clip(0.4 + 0.2 * labels + noise, 0, 1)
  thresholds = [0.3]

  # Confusion counts at the threshold. True negatives are not needed by
  # either metric, so they are not counted.
  predicted_positive = predictions > thresholds[0]
  actual_positive = labels == 1
  tp = int(np.count_nonzero(predicted_positive & actual_positive))
  fp = int(np.count_nonzero(predicted_positive & ~actual_positive))
  fn = int(np.count_nonzero(~predicted_positive & actual_positive))
  epsilon = 1e-7
  expected_prec = tp / (epsilon + tp + fp)
  expected_rec = tp / (epsilon + tp + fn)

  labels = labels.astype(np.float32)
  predictions = predictions.astype(np.float32)

  with self.test_session() as sess:
    # Reshape the data so it's easy to queue up: column i is batch i.
    predictions_batches = predictions.reshape((batch_size, num_batches))
    labels_batches = labels.reshape((batch_size, num_batches))

    # Enqueue the data:
    predictions_queue = data_flow_ops.FIFOQueue(
        num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))
    labels_queue = data_flow_ops.FIFOQueue(
        num_batches, dtypes=dtypes_lib.float32, shapes=(batch_size,))

    for i in range(num_batches):
      tf_prediction = constant_op.constant(predictions_batches[:, i])
      tf_label = constant_op.constant(labels_batches[:, i])
      sess.run([
          predictions_queue.enqueue(tf_prediction),
          labels_queue.enqueue(tf_label)
      ])

    tf_predictions = predictions_queue.dequeue()
    tf_labels = labels_queue.dequeue()

    prec, prec_op = metrics.precision_at_thresholds(tf_labels, tf_predictions,
                                                    thresholds)
    rec, rec_op = metrics.recall_at_thresholds(tf_labels, tf_predictions,
                                               thresholds)

    sess.run(variables.local_variables_initializer())
    # One update per queued batch.
    for _ in range(num_batches):
      sess.run([prec_op, rec_op])
    # Since this is only approximate, we can't expect a 6 digits match.
    # Although with higher number of samples/thresholds we should see the
    # accuracy improving
    self.assertAlmostEqual(expected_prec, prec.eval(), 2)
    self.assertAlmostEqual(expected_rec, rec.eval(), 2)
def _test_precision_at_top_k(
    predictions_idx,
    labels,
    expected,
    k=None,
    class_id=None,
    weights=None,
    test_case=None):
  """Builds `metrics.precision_at_top_k` in a fresh graph and checks one step.

  Args:
    predictions_idx: Predicted class indices; converted to an int32 constant.
    labels: Labels, passed to the metric unchanged.
    expected: Value both the update op and the metric must produce. If it is
      NaN, both must be NaN instead.
    k: Forwarded to `metrics.precision_at_top_k`.
    class_id: Forwarded to `metrics.precision_at_top_k`.
    weights: Optional weights; converted to a float32 constant when given.
    test_case: Test case supplying `test_session` and the assert methods. It
      is dereferenced unconditionally, so callers must pass one.
  """
  with ops.Graph().as_default() as g, test_case.test_session(g):
    if weights is not None:
      weights = constant_op.constant(weights, dtypes_lib.float32)
    metric, update = metrics.precision_at_top_k(
        predictions_idx=constant_op.constant(predictions_idx, dtypes_lib.int32),
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights)

    # Fails without initialized vars.
    test_case.assertRaises(errors_impl.OpError, metric.eval)
    test_case.assertRaises(errors_impl.OpError, update.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values.
    if math.isnan(expected):
      # Use the same NaN helper as the other _test_*_at_k functions in this
      # file rather than a bare assertTrue(math.isnan(...)).
      _assert_nan(test_case, update.eval())
      _assert_nan(test_case, metric.eval())
    else:
      test_case.assertEqual(expected, update.eval())
      test_case.assertEqual(expected, metric.eval())
class SingleLabelPrecisionAtKTest(test.TestCase):
  """Precision@1 checks for two examples that each carry one label."""

  def setUp(self):
    self._predictions = ((0.1, 0.3, 0.2, 0.4), (0.1, 0.2, 0.3, 0.4))
    self._predictions_idx = [[3], [3]]
    one_hot_rows = ((0, 0, 0, 1), (0, 0, 1, 0))
    label_ids = (3, 2)
    # Sparse vs dense, and 1d vs 2d labels should all be handled the same.
    sparse_1d = _binary_2d_label_to_1d_sparse_value(one_hot_rows)
    sparse_2d = _binary_2d_label_to_2d_sparse_value(one_hot_rows)
    dense_1d = np.array(label_ids, dtype=np.int64)
    dense_2d = np.array(label_ids, dtype=np.int64).reshape((-1, 1))
    self._labels = (sparse_1d, sparse_2d, dense_1d, dense_2d)
    # Bind this test case into the module-level check helpers.
    self._test_precision_at_k = functools.partial(
        _test_precision_at_k, test_case=self)
    self._test_precision_at_top_k = functools.partial(
        _test_precision_at_top_k, test_case=self)
    self._test_average_precision_at_k = functools.partial(
        _test_average_precision_at_k, test_case=self)

  def _check_precision_at_k1(self, labels, expected, class_id=None):
    """Runs the score-based check, then the index-based check, with k=1."""
    self._test_precision_at_k(
        self._predictions, labels, k=1, expected=expected, class_id=class_id)
    self._test_precision_at_top_k(
        self._predictions_idx, labels, k=1, expected=expected,
        class_id=class_id)

  def test_at_k1_nan(self):
    # Classes 0,1,2 have 0 predictions, classes -1 and 4 are out of range.
    nan_class_ids = (-1, 0, 1, 2, 4)
    for labels in self._labels:
      for class_id in nan_class_ids:
        self._check_precision_at_k1(labels, expected=NAN, class_id=class_id)

  def test_at_k1(self):
    half = 1.0 / 2
    for labels in self._labels:
      # Class 3: 1 label, 2 predictions, 1 correct.
      self._check_precision_at_k1(labels, expected=half, class_id=3)

      # All classes: 2 labels, 2 predictions, 1 correct.
      self._check_precision_at_k1(labels, expected=half)
      self._test_average_precision_at_k(
          self._predictions, labels, k=1, expected=half)
1989 labels_ex2 = (0, 2, 4, 5, 6) 1990 labels = np.array([labels_ex2], dtype=np.int64) 1991 predictions_ex2 = (0.3, 0.5, 0.0, 0.4, 0.0, 0.1, 0.2) 1992 predictions = (predictions_ex2,) 1993 predictions_idx_ex2 = (1, 3, 0, 6, 5) 1994 precision_ex2 = (0.0 / 1, 0.0 / 2, 1.0 / 3, 2.0 / 4) 1995 avg_precision_ex2 = (0.0 / 1, 0.0 / 2, precision_ex2[2] / 3, 1996 (precision_ex2[2] + precision_ex2[3]) / 4) 1997 for i in xrange(4): 1998 k = i + 1 1999 self._test_precision_at_k( 2000 predictions, labels, k, expected=precision_ex2[i]) 2001 self._test_precision_at_top_k( 2002 (predictions_idx_ex2[:k],), labels, k=k, expected=precision_ex2[i]) 2003 self._test_average_precision_at_k( 2004 predictions, labels, k, expected=avg_precision_ex2[i]) 2005 2006 # Both examples, we expect both precision and average precision to be the 2007 # average of the 2 examples. 2008 labels = np.array([labels_ex1, labels_ex2], dtype=np.int64) 2009 predictions = (predictions_ex1, predictions_ex2) 2010 streaming_precision = [(ex1 + ex2) / 2 2011 for ex1, ex2 in zip(precision_ex1, precision_ex2)] 2012 streaming_average_precision = [ 2013 (ex1 + ex2) / 2 2014 for ex1, ex2 in zip(avg_precision_ex1, avg_precision_ex2) 2015 ] 2016 for i in xrange(4): 2017 k = i + 1 2018 predictions_idx = (predictions_idx_ex1[:k], predictions_idx_ex2[:k]) 2019 self._test_precision_at_k( 2020 predictions, labels, k, expected=streaming_precision[i]) 2021 self._test_precision_at_top_k( 2022 predictions_idx, labels, k=k, expected=streaming_precision[i]) 2023 self._test_average_precision_at_k( 2024 predictions, labels, k, expected=streaming_average_precision[i]) 2025 2026 # Weighted examples, we expect streaming average precision to be the 2027 # weighted average of the 2 examples. 
def test_average_precision_some_labels_out_of_range(self):
  """Tests that labels outside the [0, n_classes) range are ignored."""
  # Labels -1 and 7 do not index any of the 7 prediction columns.
  label_ids = (-1, 0, 1, 2, 3, 4, 7)
  labels = np.array([label_ids], dtype=np.int64)
  scores = ((0.2, 0.1, 0.0, 0.4, 0.0, 0.5, 0.3),)
  ranked_ids = (5, 3, 6, 0, 1)
  # Expected precision and average precision for k = 1..4.
  precision_by_k = (0.0 / 1, 1.0 / 2, 1.0 / 3, 2.0 / 4)
  first_hit = precision_by_k[1]
  avg_precision_by_k = (0.0 / 1, first_hit / 2, first_hit / 3,
                        (first_hit + precision_by_k[3]) / 4)
  pairs = zip(precision_by_k, avg_precision_by_k)
  for k, (precision, avg_precision) in enumerate(pairs, 1):
    self._test_precision_at_k(scores, labels, k, expected=precision)
    self._test_precision_at_top_k(
        (ranked_ids[:k],), labels, k=k, expected=precision)
    self._test_average_precision_at_k(
        scores, labels, k, expected=avg_precision)
def test_three_labels_at_k5_no_labels(self):
  """Expects precision 0.0 at k=5 for classes 0, 4, 6 and 9."""
  scores = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
            [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
  top_ids = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
  label_variants = (
      _binary_2d_label_to_2d_sparse_value(
          [[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]]),
      np.array([[2, 7, 8], [1, 2, 5]], dtype=np.int64),
  )
  # Arguments shared by every check below.
  shared = dict(k=5, expected=0.0)

  # Classes 0,4,6,9: 0 labels, >=1 prediction.
  unlabeled_class_ids = (0, 4, 6, 9)
  for labels in label_variants:
    for class_id in unlabeled_class_ids:
      self._test_precision_at_k(scores, labels, class_id=class_id, **shared)
      self._test_precision_at_top_k(
          top_ids, labels, class_id=class_id, **shared)
def test_three_labels_at_k5_some_out_of_range(self):
  """Tests that labels outside the [0, n_classes) range are ignored."""
  scores = [[0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9],
            [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]]
  top_ids = [[9, 4, 6, 2, 0], [5, 7, 2, 9, 6]]
  # values -1 and 10 are outside the [0, n_classes) range and are ignored.
  label_values = np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64)
  sp_labels = sparse_tensor.SparseTensorValue(
      indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
               [1, 3]],
      values=label_values,
      dense_shape=[2, 4])

  # (class_id, expected precision); class_id None covers all classes.
  cases = (
      (2, 2.0 / 2),  # Class 2: 2 labels, 2 correct predictions.
      (5, 1.0 / 1),  # Class 5: 1 label, 1 correct prediction.
      (7, 0.0 / 1),  # Class 7: 1 label, 1 incorrect prediction.
      (None, 3.0 / 10),  # All classes: 10 predictions, 3 correct.
  )
  for class_id, want in cases:
    self._test_precision_at_k(
        scores, sp_labels, k=5, expected=want, class_id=class_id)
    self._test_precision_at_top_k(
        top_ids, sp_labels, k=5, expected=want, class_id=class_id)
def test_3d(self):
  """Checks precision at k=5 on 3-D inputs, per class and over all classes."""
  # The two distinct score rows and their top-5 class indices; the second
  # outer entry holds the same rows in swapped order.
  row_a = [0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9]
  row_b = [0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6]
  top_a = [9, 4, 6, 2, 0]
  top_b = [5, 7, 2, 9, 6]
  scores = [[row_a, row_b], [row_b, row_a]]
  top_ids = [[top_a, top_b], [top_b, top_a]]
  labels = _binary_3d_label_to_sparse_value(
      [[[0, 0, 1, 0, 0, 0, 0, 1, 1, 0], [0, 1, 1, 0, 0, 1, 0, 0, 0, 0]],
       [[0, 1, 1, 0, 0, 1, 0, 1, 0, 0], [0, 0, 1, 0, 0, 0, 0, 0, 1, 0]]])

  # (class_id, expected precision); class_id None covers all classes.
  cases = (
      (2, 4.0 / 4),  # Class 2: 4 predictions, all correct.
      (5, 2.0 / 2),  # Class 5: 2 predictions, both correct.
      (7, 1.0 / 2),  # Class 7: 2 predictions, 1 correct.
      (None, 7.0 / 20),  # All classes: 20 predictions, 7 correct.
  )
  for class_id, want in cases:
    self._test_precision_at_k(
        scores, labels, k=5, expected=want, class_id=class_id)
    self._test_precision_at_top_k(
        top_ids, labels, k=5, expected=want, class_id=class_id)
def _test_recall_at_k(predictions,
                      labels,
                      k,
                      expected,
                      class_id=None,
                      weights=None,
                      test_case=None):
  """Builds `metrics.recall_at_k` in a fresh graph and checks one update.

  Args:
    predictions: Prediction scores; converted to a float32 constant.
    labels: Labels, passed to the metric unchanged.
    k: Forwarded to `metrics.recall_at_k`.
    expected: Value both the update op and the metric must produce. If it is
      NaN, both must be NaN instead.
    class_id: Forwarded to `metrics.recall_at_k`.
    weights: Optional weights; converted to a float32 constant when given.
    test_case: Test case supplying `test_session` and the assert methods. It
      is dereferenced unconditionally, so callers must pass one.
  """
  with ops.Graph().as_default() as graph, test_case.test_session(graph):
    weights_tensor = None
    if weights is not None:
      weights_tensor = constant_op.constant(weights, dtypes_lib.float32)
    predictions_tensor = constant_op.constant(predictions, dtypes_lib.float32)
    metric, update = metrics.recall_at_k(
        predictions=predictions_tensor,
        labels=labels,
        k=k,
        class_id=class_id,
        weights=weights_tensor)

    # Fails without initialized vars.
    for tensor in (metric, update):
      test_case.assertRaises(errors_impl.OpError, tensor.eval)
    variables.variables_initializer(variables.local_variables()).run()

    # Run per-step op and assert expected values; update is evaluated first.
    expect_nan = math.isnan(expected)
    for tensor in (update, metric):
      actual = tensor.eval()
      if expect_nan:
        _assert_nan(test_case, actual)
      else:
        test_case.assertEqual(expected, actual)
def test_one_label_at_k1_weighted_class_id3(self):
  """Checks class-3 recall at k=1 under several weight vectors."""
  scores = self._predictions
  top_ids = self._predictions_idx
  # Class 3: 1 label, 2 predictions, 1 correct.
  # Each entry is (expected recall, weights), in the order the checks run.
  cases = (
      (NAN, (0.0,)),
      (1.0 / 1, (1.0,)),
      (1.0 / 1, (2.0,)),
      (NAN, (0.0, 1.0)),
      (1.0 / 1, (1.0, 0.0)),
      (2.0 / 2, (2.0, 3.0)),
  )
  for labels in self._labels:
    for want, weights in cases:
      self._test_recall_at_k(
          scores, labels, k=1, expected=want, class_id=3, weights=weights)
      self._test_recall_at_top_k(
          top_ids, labels, k=1, expected=want, class_id=3, weights=weights)
class MultiLabel2dRecallAtKTest(test.TestCase):
  """Recall@5 checks for a two-example batch with three labels per example."""

  def setUp(self):
    self._predictions = ((0.5, 0.1, 0.6, 0.3, 0.8, 0.0, 0.7, 0.2, 0.4, 0.9),
                         (0.3, 0.0, 0.7, 0.2, 0.4, 0.9, 0.5, 0.8, 0.1, 0.6))
    self._predictions_idx = ((9, 4, 6, 2, 0), (5, 7, 2, 9, 6))
    one_hot_rows = ((0, 0, 1, 0, 0, 0, 0, 1, 1, 0),
                    (0, 1, 1, 0, 0, 1, 0, 0, 0, 0))
    label_ids = ((2, 7, 8), (1, 2, 5))
    # Sparse vs dense labels should be handled the same.
    sparse_form = _binary_2d_label_to_2d_sparse_value(one_hot_rows)
    dense_form = np.array(label_ids, dtype=np.int64)
    self._labels = (sparse_form, dense_form)
    # Bind this test case into the module-level check helpers.
    self._test_recall_at_k = functools.partial(
        _test_recall_at_k, test_case=self)
    self._test_recall_at_top_k = functools.partial(
        _test_recall_at_top_k, test_case=self)

  def _check_recall_at_k5(self, labels, expected, class_id=None):
    """Runs the score-based check, then the index-based check, with k=5."""
    self._test_recall_at_k(
        self._predictions, labels, k=5, expected=expected, class_id=class_id)
    self._test_recall_at_top_k(
        self._predictions_idx, labels, k=5, expected=expected,
        class_id=class_id)

  def test_at_k5_nan(self):
    # Classes 0,3,4,6,9 have 0 labels, class 10 is out of range.
    nan_class_ids = (0, 3, 4, 6, 9, 10)
    for labels in self._labels:
      for class_id in nan_class_ids:
        self._check_recall_at_k5(labels, expected=NAN, class_id=class_id)

  def test_at_k5_no_predictions(self):
    for labels in self._labels:
      # Class 8: 1 label, no predictions.
      self._check_recall_at_k5(labels, expected=0.0 / 1, class_id=8)

  def test_at_k5(self):
    # (class_id, expected recall); class_id None covers all classes.
    cases = (
        (2, 2.0 / 2),  # Class 2: 2 labels, both correct.
        (5, 1.0 / 1),  # Class 5: 1 label, correct.
        (7, 0.0 / 1),  # Class 7: 1 label, incorrect.
        (None, 3.0 / 6),  # All classes: 6 labels, 3 correct.
    )
    for labels in self._labels:
      for class_id, want in cases:
        self._check_recall_at_k5(labels, expected=want, class_id=class_id)

  def test_at_k5_some_out_of_range(self):
    """Tests that labels outside the [0, n_classes) count in denominator."""
    # values -1 and 10 are outside the [0, n_classes) range.
    label_values = np.array([2, 7, -1, 8, 1, 2, 5, 10], np.int64)
    labels = sparse_tensor.SparseTensorValue(
        indices=[[0, 0], [0, 1], [0, 2], [0, 3], [1, 0], [1, 1], [1, 2],
                 [1, 3]],
        values=label_values,
        dense_shape=[2, 4])

    # (class_id, expected recall); class_id None covers all classes.
    cases = (
        (2, 2.0 / 2),  # Class 2: 2 labels, both correct.
        (5, 1.0 / 1),  # Class 5: 1 label, correct.
        (7, 0.0 / 1),  # Class 7: 1 label, incorrect.
        (None, 3.0 / 8),  # All classes: 8 labels, 3 correct.
    )
    for class_id, want in cases:
      self._check_recall_at_k5(labels, expected=want, class_id=class_id)
2609 self._test_recall_at_k( 2610 self._predictions, self._labels, k=5, expected=4.0 / 4, class_id=2) 2611 self._test_recall_at_top_k( 2612 self._predictions_idx, self._labels, k=5, expected=4.0 / 4, 2613 class_id=2) 2614 2615 # Class 5: 2 labels, both correct. 2616 self._test_recall_at_k( 2617 self._predictions, self._labels, k=5, expected=2.0 / 2, class_id=5) 2618 self._test_recall_at_top_k( 2619 self._predictions_idx, self._labels, k=5, expected=2.0 / 2, 2620 class_id=5) 2621 2622 # Class 7: 2 labels, 1 incorrect. 2623 self._test_recall_at_k( 2624 self._predictions, self._labels, k=5, expected=1.0 / 2, class_id=7) 2625 self._test_recall_at_top_k( 2626 self._predictions_idx, self._labels, k=5, expected=1.0 / 2, 2627 class_id=7) 2628 2629 # All classes: 12 labels, 7 correct. 2630 self._test_recall_at_k( 2631 self._predictions, self._labels, k=5, expected=7.0 / 12) 2632 self._test_recall_at_top_k( 2633 self._predictions_idx, self._labels, k=5, expected=7.0 / 12) 2634 2635 def test_3d_ignore_all(self): 2636 for class_id in xrange(10): 2637 self._test_recall_at_k( 2638 self._predictions, self._labels, k=5, expected=NAN, class_id=class_id, 2639 weights=[[0], [0]]) 2640 self._test_recall_at_top_k( 2641 self._predictions_idx, self._labels, k=5, expected=NAN, 2642 class_id=class_id, weights=[[0], [0]]) 2643 self._test_recall_at_k( 2644 self._predictions, self._labels, k=5, expected=NAN, class_id=class_id, 2645 weights=[[0, 0], [0, 0]]) 2646 self._test_recall_at_top_k( 2647 self._predictions_idx, self._labels, k=5, expected=NAN, 2648 class_id=class_id, weights=[[0, 0], [0, 0]]) 2649 self._test_recall_at_k( 2650 self._predictions, self._labels, k=5, expected=NAN, weights=[[0], [0]]) 2651 self._test_recall_at_top_k( 2652 self._predictions_idx, self._labels, k=5, expected=NAN, 2653 weights=[[0], [0]]) 2654 self._test_recall_at_k( 2655 self._predictions, self._labels, k=5, expected=NAN, 2656 weights=[[0, 0], [0, 0]]) 2657 self._test_recall_at_top_k( 2658 
self._predictions_idx, self._labels, k=5, expected=NAN, 2659 weights=[[0, 0], [0, 0]]) 2660 2661 def test_3d_ignore_some(self): 2662 # Class 2: 2 labels, both correct. 2663 self._test_recall_at_k( 2664 self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2, 2665 weights=[[1], [0]]) 2666 self._test_recall_at_top_k( 2667 self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0, 2668 class_id=2, weights=[[1], [0]]) 2669 2670 # Class 2: 2 labels, both correct. 2671 self._test_recall_at_k( 2672 self._predictions, self._labels, k=5, expected=2.0 / 2.0, class_id=2, 2673 weights=[[0], [1]]) 2674 self._test_recall_at_top_k( 2675 self._predictions_idx, self._labels, k=5, expected=2.0 / 2.0, 2676 class_id=2, weights=[[0], [1]]) 2677 2678 # Class 7: 1 label, correct. 2679 self._test_recall_at_k( 2680 self._predictions, self._labels, k=5, expected=1.0 / 1.0, class_id=7, 2681 weights=[[0], [1]]) 2682 self._test_recall_at_top_k( 2683 self._predictions_idx, self._labels, k=5, expected=1.0 / 1.0, 2684 class_id=7, weights=[[0], [1]]) 2685 2686 # Class 7: 1 label, incorrect. 2687 self._test_recall_at_k( 2688 self._predictions, self._labels, k=5, expected=0.0 / 1.0, class_id=7, 2689 weights=[[1], [0]]) 2690 self._test_recall_at_top_k( 2691 self._predictions_idx, self._labels, k=5, expected=0.0 / 1.0, 2692 class_id=7, weights=[[1], [0]]) 2693 2694 # Class 7: 2 labels, 1 correct. 2695 self._test_recall_at_k( 2696 self._predictions, self._labels, k=5, expected=1.0 / 2.0, class_id=7, 2697 weights=[[1, 0], [1, 0]]) 2698 self._test_recall_at_top_k( 2699 self._predictions_idx, self._labels, k=5, expected=1.0 / 2.0, 2700 class_id=7, weights=[[1, 0], [1, 0]]) 2701 2702 # Class 7: No labels. 
2703 self._test_recall_at_k( 2704 self._predictions, self._labels, k=5, expected=NAN, class_id=7, 2705 weights=[[0, 1], [0, 1]]) 2706 self._test_recall_at_top_k( 2707 self._predictions_idx, self._labels, k=5, expected=NAN, class_id=7, 2708 weights=[[0, 1], [0, 1]]) 2709 2710 2711 class MeanAbsoluteErrorTest(test.TestCase): 2712 2713 def setUp(self): 2714 ops.reset_default_graph() 2715 2716 def testVars(self): 2717 metrics.mean_absolute_error( 2718 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 2719 _assert_metric_variables( 2720 self, ('mean_absolute_error/count:0', 'mean_absolute_error/total:0')) 2721 2722 def testMetricsCollection(self): 2723 my_collection_name = '__metrics__' 2724 mean, _ = metrics.mean_absolute_error( 2725 predictions=array_ops.ones((10, 1)), 2726 labels=array_ops.ones((10, 1)), 2727 metrics_collections=[my_collection_name]) 2728 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 2729 2730 def testUpdatesCollection(self): 2731 my_collection_name = '__updates__' 2732 _, update_op = metrics.mean_absolute_error( 2733 predictions=array_ops.ones((10, 1)), 2734 labels=array_ops.ones((10, 1)), 2735 updates_collections=[my_collection_name]) 2736 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 2737 2738 def testValueTensorIsIdempotent(self): 2739 predictions = random_ops.random_normal((10, 3), seed=1) 2740 labels = random_ops.random_normal((10, 3), seed=2) 2741 error, update_op = metrics.mean_absolute_error(labels, predictions) 2742 2743 with self.test_session() as sess: 2744 sess.run(variables.local_variables_initializer()) 2745 2746 # Run several updates. 2747 for _ in range(10): 2748 sess.run(update_op) 2749 2750 # Then verify idempotency. 
2751 initial_error = error.eval() 2752 for _ in range(10): 2753 self.assertEqual(initial_error, error.eval()) 2754 2755 def testSingleUpdateWithErrorAndWeights(self): 2756 predictions = constant_op.constant( 2757 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 2758 labels = constant_op.constant( 2759 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 2760 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 2761 2762 error, update_op = metrics.mean_absolute_error(labels, predictions, weights) 2763 2764 with self.test_session() as sess: 2765 sess.run(variables.local_variables_initializer()) 2766 self.assertEqual(3, sess.run(update_op)) 2767 self.assertEqual(3, error.eval()) 2768 2769 2770 class MeanRelativeErrorTest(test.TestCase): 2771 2772 def setUp(self): 2773 ops.reset_default_graph() 2774 2775 def testVars(self): 2776 metrics.mean_relative_error( 2777 predictions=array_ops.ones((10, 1)), 2778 labels=array_ops.ones((10, 1)), 2779 normalizer=array_ops.ones((10, 1))) 2780 _assert_metric_variables( 2781 self, ('mean_relative_error/count:0', 'mean_relative_error/total:0')) 2782 2783 def testMetricsCollection(self): 2784 my_collection_name = '__metrics__' 2785 mean, _ = metrics.mean_relative_error( 2786 predictions=array_ops.ones((10, 1)), 2787 labels=array_ops.ones((10, 1)), 2788 normalizer=array_ops.ones((10, 1)), 2789 metrics_collections=[my_collection_name]) 2790 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 2791 2792 def testUpdatesCollection(self): 2793 my_collection_name = '__updates__' 2794 _, update_op = metrics.mean_relative_error( 2795 predictions=array_ops.ones((10, 1)), 2796 labels=array_ops.ones((10, 1)), 2797 normalizer=array_ops.ones((10, 1)), 2798 updates_collections=[my_collection_name]) 2799 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 2800 2801 def testValueTensorIsIdempotent(self): 2802 predictions = random_ops.random_normal((10, 3), seed=1) 2803 labels = 
random_ops.random_normal((10, 3), seed=2) 2804 normalizer = random_ops.random_normal((10, 3), seed=3) 2805 error, update_op = metrics.mean_relative_error(labels, predictions, 2806 normalizer) 2807 2808 with self.test_session() as sess: 2809 sess.run(variables.local_variables_initializer()) 2810 2811 # Run several updates. 2812 for _ in range(10): 2813 sess.run(update_op) 2814 2815 # Then verify idempotency. 2816 initial_error = error.eval() 2817 for _ in range(10): 2818 self.assertEqual(initial_error, error.eval()) 2819 2820 def testSingleUpdateNormalizedByLabels(self): 2821 np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32) 2822 np_labels = np.asarray([1, 3, 2, 3], dtype=np.float32) 2823 expected_error = np.mean( 2824 np.divide(np.absolute(np_predictions - np_labels), np_labels)) 2825 2826 predictions = constant_op.constant( 2827 np_predictions, shape=(1, 4), dtype=dtypes_lib.float32) 2828 labels = constant_op.constant(np_labels, shape=(1, 4)) 2829 2830 error, update_op = metrics.mean_relative_error( 2831 labels, predictions, normalizer=labels) 2832 2833 with self.test_session() as sess: 2834 sess.run(variables.local_variables_initializer()) 2835 self.assertEqual(expected_error, sess.run(update_op)) 2836 self.assertEqual(expected_error, error.eval()) 2837 2838 def testSingleUpdateNormalizedByZeros(self): 2839 np_predictions = np.asarray([2, 4, 6, 8], dtype=np.float32) 2840 2841 predictions = constant_op.constant( 2842 np_predictions, shape=(1, 4), dtype=dtypes_lib.float32) 2843 labels = constant_op.constant( 2844 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 2845 2846 error, update_op = metrics.mean_relative_error( 2847 labels, predictions, normalizer=array_ops.zeros_like(labels)) 2848 2849 with self.test_session() as sess: 2850 sess.run(variables.local_variables_initializer()) 2851 self.assertEqual(0.0, sess.run(update_op)) 2852 self.assertEqual(0.0, error.eval()) 2853 2854 2855 class MeanSquaredErrorTest(test.TestCase): 2856 2857 def 
setUp(self): 2858 ops.reset_default_graph() 2859 2860 def testVars(self): 2861 metrics.mean_squared_error( 2862 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 2863 _assert_metric_variables( 2864 self, ('mean_squared_error/count:0', 'mean_squared_error/total:0')) 2865 2866 def testMetricsCollection(self): 2867 my_collection_name = '__metrics__' 2868 mean, _ = metrics.mean_squared_error( 2869 predictions=array_ops.ones((10, 1)), 2870 labels=array_ops.ones((10, 1)), 2871 metrics_collections=[my_collection_name]) 2872 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 2873 2874 def testUpdatesCollection(self): 2875 my_collection_name = '__updates__' 2876 _, update_op = metrics.mean_squared_error( 2877 predictions=array_ops.ones((10, 1)), 2878 labels=array_ops.ones((10, 1)), 2879 updates_collections=[my_collection_name]) 2880 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 2881 2882 def testValueTensorIsIdempotent(self): 2883 predictions = random_ops.random_normal((10, 3), seed=1) 2884 labels = random_ops.random_normal((10, 3), seed=2) 2885 error, update_op = metrics.mean_squared_error(labels, predictions) 2886 2887 with self.test_session() as sess: 2888 sess.run(variables.local_variables_initializer()) 2889 2890 # Run several updates. 2891 for _ in range(10): 2892 sess.run(update_op) 2893 2894 # Then verify idempotency. 
2895 initial_error = error.eval() 2896 for _ in range(10): 2897 self.assertEqual(initial_error, error.eval()) 2898 2899 def testSingleUpdateZeroError(self): 2900 predictions = array_ops.zeros((1, 3), dtype=dtypes_lib.float32) 2901 labels = array_ops.zeros((1, 3), dtype=dtypes_lib.float32) 2902 2903 error, update_op = metrics.mean_squared_error(labels, predictions) 2904 2905 with self.test_session() as sess: 2906 sess.run(variables.local_variables_initializer()) 2907 self.assertEqual(0, sess.run(update_op)) 2908 self.assertEqual(0, error.eval()) 2909 2910 def testSingleUpdateWithError(self): 2911 predictions = constant_op.constant( 2912 [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) 2913 labels = constant_op.constant( 2914 [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) 2915 2916 error, update_op = metrics.mean_squared_error(labels, predictions) 2917 2918 with self.test_session() as sess: 2919 sess.run(variables.local_variables_initializer()) 2920 self.assertEqual(6, sess.run(update_op)) 2921 self.assertEqual(6, error.eval()) 2922 2923 def testSingleUpdateWithErrorAndWeights(self): 2924 predictions = constant_op.constant( 2925 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 2926 labels = constant_op.constant( 2927 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 2928 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 2929 2930 error, update_op = metrics.mean_squared_error(labels, predictions, weights) 2931 2932 with self.test_session() as sess: 2933 sess.run(variables.local_variables_initializer()) 2934 self.assertEqual(13, sess.run(update_op)) 2935 self.assertEqual(13, error.eval()) 2936 2937 def testMultipleBatchesOfSizeOne(self): 2938 with self.test_session() as sess: 2939 # Create the queue that populates the predictions. 
2940 preds_queue = data_flow_ops.FIFOQueue( 2941 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2942 _enqueue_vector(sess, preds_queue, [10, 8, 6]) 2943 _enqueue_vector(sess, preds_queue, [-4, 3, -1]) 2944 predictions = preds_queue.dequeue() 2945 2946 # Create the queue that populates the labels. 2947 labels_queue = data_flow_ops.FIFOQueue( 2948 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2949 _enqueue_vector(sess, labels_queue, [1, 3, 2]) 2950 _enqueue_vector(sess, labels_queue, [2, 4, 6]) 2951 labels = labels_queue.dequeue() 2952 2953 error, update_op = metrics.mean_squared_error(labels, predictions) 2954 2955 sess.run(variables.local_variables_initializer()) 2956 sess.run(update_op) 2957 self.assertAlmostEqual(208.0 / 6, sess.run(update_op), 5) 2958 2959 self.assertAlmostEqual(208.0 / 6, error.eval(), 5) 2960 2961 def testMetricsComputedConcurrently(self): 2962 with self.test_session() as sess: 2963 # Create the queue that populates one set of predictions. 2964 preds_queue0 = data_flow_ops.FIFOQueue( 2965 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2966 _enqueue_vector(sess, preds_queue0, [10, 8, 6]) 2967 _enqueue_vector(sess, preds_queue0, [-4, 3, -1]) 2968 predictions0 = preds_queue0.dequeue() 2969 2970 # Create the queue that populates one set of predictions. 2971 preds_queue1 = data_flow_ops.FIFOQueue( 2972 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2973 _enqueue_vector(sess, preds_queue1, [0, 1, 1]) 2974 _enqueue_vector(sess, preds_queue1, [1, 1, 0]) 2975 predictions1 = preds_queue1.dequeue() 2976 2977 # Create the queue that populates one set of labels. 2978 labels_queue0 = data_flow_ops.FIFOQueue( 2979 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2980 _enqueue_vector(sess, labels_queue0, [1, 3, 2]) 2981 _enqueue_vector(sess, labels_queue0, [2, 4, 6]) 2982 labels0 = labels_queue0.dequeue() 2983 2984 # Create the queue that populates another set of labels. 
2985 labels_queue1 = data_flow_ops.FIFOQueue( 2986 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 2987 _enqueue_vector(sess, labels_queue1, [-5, -3, -1]) 2988 _enqueue_vector(sess, labels_queue1, [5, 4, 3]) 2989 labels1 = labels_queue1.dequeue() 2990 2991 mse0, update_op0 = metrics.mean_squared_error( 2992 labels0, predictions0, name='msd0') 2993 mse1, update_op1 = metrics.mean_squared_error( 2994 labels1, predictions1, name='msd1') 2995 2996 sess.run(variables.local_variables_initializer()) 2997 sess.run([update_op0, update_op1]) 2998 sess.run([update_op0, update_op1]) 2999 3000 mse0, mse1 = sess.run([mse0, mse1]) 3001 self.assertAlmostEqual(208.0 / 6, mse0, 5) 3002 self.assertAlmostEqual(79.0 / 6, mse1, 5) 3003 3004 def testMultipleMetricsOnMultipleBatchesOfSizeOne(self): 3005 with self.test_session() as sess: 3006 # Create the queue that populates the predictions. 3007 preds_queue = data_flow_ops.FIFOQueue( 3008 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3009 _enqueue_vector(sess, preds_queue, [10, 8, 6]) 3010 _enqueue_vector(sess, preds_queue, [-4, 3, -1]) 3011 predictions = preds_queue.dequeue() 3012 3013 # Create the queue that populates the labels. 
3014 labels_queue = data_flow_ops.FIFOQueue( 3015 2, dtypes=dtypes_lib.float32, shapes=(1, 3)) 3016 _enqueue_vector(sess, labels_queue, [1, 3, 2]) 3017 _enqueue_vector(sess, labels_queue, [2, 4, 6]) 3018 labels = labels_queue.dequeue() 3019 3020 mae, ma_update_op = metrics.mean_absolute_error(labels, predictions) 3021 mse, ms_update_op = metrics.mean_squared_error(labels, predictions) 3022 3023 sess.run(variables.local_variables_initializer()) 3024 sess.run([ma_update_op, ms_update_op]) 3025 sess.run([ma_update_op, ms_update_op]) 3026 3027 self.assertAlmostEqual(32.0 / 6, mae.eval(), 5) 3028 self.assertAlmostEqual(208.0 / 6, mse.eval(), 5) 3029 3030 3031 class RootMeanSquaredErrorTest(test.TestCase): 3032 3033 def setUp(self): 3034 ops.reset_default_graph() 3035 3036 def testVars(self): 3037 metrics.root_mean_squared_error( 3038 predictions=array_ops.ones((10, 1)), labels=array_ops.ones((10, 1))) 3039 _assert_metric_variables( 3040 self, 3041 ('root_mean_squared_error/count:0', 'root_mean_squared_error/total:0')) 3042 3043 def testMetricsCollection(self): 3044 my_collection_name = '__metrics__' 3045 mean, _ = metrics.root_mean_squared_error( 3046 predictions=array_ops.ones((10, 1)), 3047 labels=array_ops.ones((10, 1)), 3048 metrics_collections=[my_collection_name]) 3049 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3050 3051 def testUpdatesCollection(self): 3052 my_collection_name = '__updates__' 3053 _, update_op = metrics.root_mean_squared_error( 3054 predictions=array_ops.ones((10, 1)), 3055 labels=array_ops.ones((10, 1)), 3056 updates_collections=[my_collection_name]) 3057 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3058 3059 def testValueTensorIsIdempotent(self): 3060 predictions = random_ops.random_normal((10, 3), seed=1) 3061 labels = random_ops.random_normal((10, 3), seed=2) 3062 error, update_op = metrics.root_mean_squared_error(labels, predictions) 3063 3064 with self.test_session() as sess: 3065 
sess.run(variables.local_variables_initializer()) 3066 3067 # Run several updates. 3068 for _ in range(10): 3069 sess.run(update_op) 3070 3071 # Then verify idempotency. 3072 initial_error = error.eval() 3073 for _ in range(10): 3074 self.assertEqual(initial_error, error.eval()) 3075 3076 def testSingleUpdateZeroError(self): 3077 with self.test_session() as sess: 3078 predictions = constant_op.constant( 3079 0.0, shape=(1, 3), dtype=dtypes_lib.float32) 3080 labels = constant_op.constant(0.0, shape=(1, 3), dtype=dtypes_lib.float32) 3081 3082 rmse, update_op = metrics.root_mean_squared_error(labels, predictions) 3083 3084 sess.run(variables.local_variables_initializer()) 3085 self.assertEqual(0, sess.run(update_op)) 3086 3087 self.assertEqual(0, rmse.eval()) 3088 3089 def testSingleUpdateWithError(self): 3090 with self.test_session() as sess: 3091 predictions = constant_op.constant( 3092 [2, 4, 6], shape=(1, 3), dtype=dtypes_lib.float32) 3093 labels = constant_op.constant( 3094 [1, 3, 2], shape=(1, 3), dtype=dtypes_lib.float32) 3095 3096 rmse, update_op = metrics.root_mean_squared_error(labels, predictions) 3097 3098 sess.run(variables.local_variables_initializer()) 3099 self.assertAlmostEqual(math.sqrt(6), update_op.eval(), 5) 3100 self.assertAlmostEqual(math.sqrt(6), rmse.eval(), 5) 3101 3102 def testSingleUpdateWithErrorAndWeights(self): 3103 with self.test_session() as sess: 3104 predictions = constant_op.constant( 3105 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3106 labels = constant_op.constant( 3107 [1, 3, 2, 3], shape=(1, 4), dtype=dtypes_lib.float32) 3108 weights = constant_op.constant([0, 1, 0, 1], shape=(1, 4)) 3109 3110 rmse, update_op = metrics.root_mean_squared_error(labels, predictions, 3111 weights) 3112 3113 sess.run(variables.local_variables_initializer()) 3114 self.assertAlmostEqual(math.sqrt(13), sess.run(update_op)) 3115 3116 self.assertAlmostEqual(math.sqrt(13), rmse.eval(), 5) 3117 3118 3119 def _reweight(predictions, labels, 
weights): 3120 return (np.concatenate([[p] * int(w) for p, w in zip(predictions, weights)]), 3121 np.concatenate([[l] * int(w) for l, w in zip(labels, weights)])) 3122 3123 3124 class MeanCosineDistanceTest(test.TestCase): 3125 3126 def setUp(self): 3127 ops.reset_default_graph() 3128 3129 def testVars(self): 3130 metrics.mean_cosine_distance( 3131 predictions=array_ops.ones((10, 3)), 3132 labels=array_ops.ones((10, 3)), 3133 dim=1) 3134 _assert_metric_variables(self, ( 3135 'mean_cosine_distance/count:0', 3136 'mean_cosine_distance/total:0', 3137 )) 3138 3139 def testMetricsCollection(self): 3140 my_collection_name = '__metrics__' 3141 mean, _ = metrics.mean_cosine_distance( 3142 predictions=array_ops.ones((10, 3)), 3143 labels=array_ops.ones((10, 3)), 3144 dim=1, 3145 metrics_collections=[my_collection_name]) 3146 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3147 3148 def testUpdatesCollection(self): 3149 my_collection_name = '__updates__' 3150 _, update_op = metrics.mean_cosine_distance( 3151 predictions=array_ops.ones((10, 3)), 3152 labels=array_ops.ones((10, 3)), 3153 dim=1, 3154 updates_collections=[my_collection_name]) 3155 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3156 3157 def testValueTensorIsIdempotent(self): 3158 predictions = random_ops.random_normal((10, 3), seed=1) 3159 labels = random_ops.random_normal((10, 3), seed=2) 3160 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=1) 3161 3162 with self.test_session() as sess: 3163 sess.run(variables.local_variables_initializer()) 3164 3165 # Run several updates. 3166 for _ in range(10): 3167 sess.run(update_op) 3168 3169 # Then verify idempotency. 
3170 initial_error = error.eval() 3171 for _ in range(10): 3172 self.assertEqual(initial_error, error.eval()) 3173 3174 def testSingleUpdateZeroError(self): 3175 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3176 3177 predictions = constant_op.constant( 3178 np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32) 3179 labels = constant_op.constant( 3180 np_labels, shape=(1, 3, 3), dtype=dtypes_lib.float32) 3181 3182 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3183 3184 with self.test_session() as sess: 3185 sess.run(variables.local_variables_initializer()) 3186 self.assertEqual(0, sess.run(update_op)) 3187 self.assertEqual(0, error.eval()) 3188 3189 def testSingleUpdateWithError1(self): 3190 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3191 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3192 3193 predictions = constant_op.constant( 3194 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3195 labels = constant_op.constant( 3196 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3197 3198 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3199 3200 with self.test_session() as sess: 3201 sess.run(variables.local_variables_initializer()) 3202 self.assertAlmostEqual(1, sess.run(update_op), 5) 3203 self.assertAlmostEqual(1, error.eval(), 5) 3204 3205 def testSingleUpdateWithError2(self): 3206 np_predictions = np.matrix( 3207 ('0.819031913261206 0.567041924552012 0.087465312324590;' 3208 '-0.665139432070255 -0.739487441769973 -0.103671883216994;' 3209 '0.707106781186548 -0.707106781186548 0')) 3210 np_labels = np.matrix( 3211 ('0.819031913261206 0.567041924552012 0.087465312324590;' 3212 '0.665139432070255 0.739487441769973 0.103671883216994;' 3213 '0.707106781186548 0.707106781186548 0')) 3214 3215 predictions = constant_op.constant( 3216 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3217 labels = constant_op.constant( 3218 np_labels, shape=(3, 1, 3), 
dtype=dtypes_lib.float32) 3219 error, update_op = metrics.mean_cosine_distance(labels, predictions, dim=2) 3220 3221 with self.test_session() as sess: 3222 sess.run(variables.local_variables_initializer()) 3223 self.assertAlmostEqual(1.0, sess.run(update_op), 5) 3224 self.assertAlmostEqual(1.0, error.eval(), 5) 3225 3226 def testSingleUpdateWithErrorAndWeights1(self): 3227 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3228 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3229 3230 predictions = constant_op.constant( 3231 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3232 labels = constant_op.constant( 3233 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3234 weights = constant_op.constant( 3235 [1, 0, 0], shape=(3, 1, 1), dtype=dtypes_lib.float32) 3236 3237 error, update_op = metrics.mean_cosine_distance( 3238 labels, predictions, dim=2, weights=weights) 3239 3240 with self.test_session() as sess: 3241 sess.run(variables.local_variables_initializer()) 3242 self.assertEqual(0, sess.run(update_op)) 3243 self.assertEqual(0, error.eval()) 3244 3245 def testSingleUpdateWithErrorAndWeights2(self): 3246 np_predictions = np.matrix(('1 0 0;' '0 0 -1;' '1 0 0')) 3247 np_labels = np.matrix(('1 0 0;' '0 0 1;' '0 1 0')) 3248 3249 predictions = constant_op.constant( 3250 np_predictions, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3251 labels = constant_op.constant( 3252 np_labels, shape=(3, 1, 3), dtype=dtypes_lib.float32) 3253 weights = constant_op.constant( 3254 [0, 1, 1], shape=(3, 1, 1), dtype=dtypes_lib.float32) 3255 3256 error, update_op = metrics.mean_cosine_distance( 3257 labels, predictions, dim=2, weights=weights) 3258 3259 with self.test_session() as sess: 3260 sess.run(variables.local_variables_initializer()) 3261 self.assertEqual(1.5, update_op.eval()) 3262 self.assertEqual(1.5, error.eval()) 3263 3264 3265 class PcntBelowThreshTest(test.TestCase): 3266 3267 def setUp(self): 3268 ops.reset_default_graph() 3269 3270 def testVars(self): 
3271 metrics.percentage_below(values=array_ops.ones((10,)), threshold=2) 3272 _assert_metric_variables(self, ( 3273 'percentage_below_threshold/count:0', 3274 'percentage_below_threshold/total:0', 3275 )) 3276 3277 def testMetricsCollection(self): 3278 my_collection_name = '__metrics__' 3279 mean, _ = metrics.percentage_below( 3280 values=array_ops.ones((10,)), 3281 threshold=2, 3282 metrics_collections=[my_collection_name]) 3283 self.assertListEqual(ops.get_collection(my_collection_name), [mean]) 3284 3285 def testUpdatesCollection(self): 3286 my_collection_name = '__updates__' 3287 _, update_op = metrics.percentage_below( 3288 values=array_ops.ones((10,)), 3289 threshold=2, 3290 updates_collections=[my_collection_name]) 3291 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3292 3293 def testOneUpdate(self): 3294 with self.test_session() as sess: 3295 values = constant_op.constant( 3296 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3297 3298 pcnt0, update_op0 = metrics.percentage_below(values, 100, name='high') 3299 pcnt1, update_op1 = metrics.percentage_below(values, 7, name='medium') 3300 pcnt2, update_op2 = metrics.percentage_below(values, 1, name='low') 3301 3302 sess.run(variables.local_variables_initializer()) 3303 sess.run([update_op0, update_op1, update_op2]) 3304 3305 pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2]) 3306 self.assertAlmostEqual(1.0, pcnt0, 5) 3307 self.assertAlmostEqual(0.75, pcnt1, 5) 3308 self.assertAlmostEqual(0.0, pcnt2, 5) 3309 3310 def testSomePresentOneUpdate(self): 3311 with self.test_session() as sess: 3312 values = constant_op.constant( 3313 [2, 4, 6, 8], shape=(1, 4), dtype=dtypes_lib.float32) 3314 weights = constant_op.constant( 3315 [1, 0, 0, 1], shape=(1, 4), dtype=dtypes_lib.float32) 3316 3317 pcnt0, update_op0 = metrics.percentage_below( 3318 values, 100, weights=weights, name='high') 3319 pcnt1, update_op1 = metrics.percentage_below( 3320 values, 7, weights=weights, name='medium') 3321 
pcnt2, update_op2 = metrics.percentage_below( 3322 values, 1, weights=weights, name='low') 3323 3324 sess.run(variables.local_variables_initializer()) 3325 self.assertListEqual([1.0, 0.5, 0.0], 3326 sess.run([update_op0, update_op1, update_op2])) 3327 3328 pcnt0, pcnt1, pcnt2 = sess.run([pcnt0, pcnt1, pcnt2]) 3329 self.assertAlmostEqual(1.0, pcnt0, 5) 3330 self.assertAlmostEqual(0.5, pcnt1, 5) 3331 self.assertAlmostEqual(0.0, pcnt2, 5) 3332 3333 3334 class MeanIOUTest(test.TestCase): 3335 3336 def setUp(self): 3337 np.random.seed(1) 3338 ops.reset_default_graph() 3339 3340 def testVars(self): 3341 metrics.mean_iou( 3342 predictions=array_ops.ones([10, 1]), 3343 labels=array_ops.ones([10, 1]), 3344 num_classes=2) 3345 _assert_metric_variables(self, ('mean_iou/total_confusion_matrix:0',)) 3346 3347 def testMetricsCollections(self): 3348 my_collection_name = '__metrics__' 3349 mean_iou, _ = metrics.mean_iou( 3350 predictions=array_ops.ones([10, 1]), 3351 labels=array_ops.ones([10, 1]), 3352 num_classes=2, 3353 metrics_collections=[my_collection_name]) 3354 self.assertListEqual(ops.get_collection(my_collection_name), [mean_iou]) 3355 3356 def testUpdatesCollection(self): 3357 my_collection_name = '__updates__' 3358 _, update_op = metrics.mean_iou( 3359 predictions=array_ops.ones([10, 1]), 3360 labels=array_ops.ones([10, 1]), 3361 num_classes=2, 3362 updates_collections=[my_collection_name]) 3363 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3364 3365 def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self): 3366 predictions = array_ops.ones([10, 3]) 3367 labels = array_ops.ones([10, 4]) 3368 with self.assertRaises(ValueError): 3369 metrics.mean_iou(labels, predictions, num_classes=2) 3370 3371 def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self): 3372 predictions = array_ops.ones([10]) 3373 labels = array_ops.ones([10]) 3374 weights = array_ops.zeros([9]) 3375 with self.assertRaises(ValueError): 3376 
metrics.mean_iou(labels, predictions, num_classes=2, weights=weights) 3377 3378 def testValueTensorIsIdempotent(self): 3379 num_classes = 3 3380 predictions = random_ops.random_uniform( 3381 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3382 labels = random_ops.random_uniform( 3383 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3384 mean_iou, update_op = metrics.mean_iou( 3385 labels, predictions, num_classes=num_classes) 3386 3387 with self.test_session() as sess: 3388 sess.run(variables.local_variables_initializer()) 3389 3390 # Run several updates. 3391 for _ in range(10): 3392 sess.run(update_op) 3393 3394 # Then verify idempotency. 3395 initial_mean_iou = mean_iou.eval() 3396 for _ in range(10): 3397 self.assertEqual(initial_mean_iou, mean_iou.eval()) 3398 3399 def testMultipleUpdates(self): 3400 num_classes = 3 3401 with self.test_session() as sess: 3402 # Create the queue that populates the predictions. 3403 preds_queue = data_flow_ops.FIFOQueue( 3404 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3405 _enqueue_vector(sess, preds_queue, [0]) 3406 _enqueue_vector(sess, preds_queue, [1]) 3407 _enqueue_vector(sess, preds_queue, [2]) 3408 _enqueue_vector(sess, preds_queue, [1]) 3409 _enqueue_vector(sess, preds_queue, [0]) 3410 predictions = preds_queue.dequeue() 3411 3412 # Create the queue that populates the labels. 
3413 labels_queue = data_flow_ops.FIFOQueue( 3414 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3415 _enqueue_vector(sess, labels_queue, [0]) 3416 _enqueue_vector(sess, labels_queue, [1]) 3417 _enqueue_vector(sess, labels_queue, [1]) 3418 _enqueue_vector(sess, labels_queue, [2]) 3419 _enqueue_vector(sess, labels_queue, [1]) 3420 labels = labels_queue.dequeue() 3421 3422 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3423 3424 sess.run(variables.local_variables_initializer()) 3425 for _ in range(5): 3426 sess.run(update_op) 3427 desired_output = np.mean([1.0 / 2.0, 1.0 / 4.0, 0.]) 3428 self.assertEqual(desired_output, miou.eval()) 3429 3430 def testMultipleUpdatesWithWeights(self): 3431 num_classes = 2 3432 with self.test_session() as sess: 3433 # Create the queue that populates the predictions. 3434 preds_queue = data_flow_ops.FIFOQueue( 3435 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3436 _enqueue_vector(sess, preds_queue, [0]) 3437 _enqueue_vector(sess, preds_queue, [1]) 3438 _enqueue_vector(sess, preds_queue, [0]) 3439 _enqueue_vector(sess, preds_queue, [1]) 3440 _enqueue_vector(sess, preds_queue, [0]) 3441 _enqueue_vector(sess, preds_queue, [1]) 3442 predictions = preds_queue.dequeue() 3443 3444 # Create the queue that populates the labels. 3445 labels_queue = data_flow_ops.FIFOQueue( 3446 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3447 _enqueue_vector(sess, labels_queue, [0]) 3448 _enqueue_vector(sess, labels_queue, [1]) 3449 _enqueue_vector(sess, labels_queue, [1]) 3450 _enqueue_vector(sess, labels_queue, [0]) 3451 _enqueue_vector(sess, labels_queue, [0]) 3452 _enqueue_vector(sess, labels_queue, [1]) 3453 labels = labels_queue.dequeue() 3454 3455 # Create the queue that populates the weights. 
3456 weights_queue = data_flow_ops.FIFOQueue( 3457 6, dtypes=dtypes_lib.float32, shapes=(1, 1)) 3458 _enqueue_vector(sess, weights_queue, [1.0]) 3459 _enqueue_vector(sess, weights_queue, [1.0]) 3460 _enqueue_vector(sess, weights_queue, [1.0]) 3461 _enqueue_vector(sess, weights_queue, [0.0]) 3462 _enqueue_vector(sess, weights_queue, [1.0]) 3463 _enqueue_vector(sess, weights_queue, [0.0]) 3464 weights = weights_queue.dequeue() 3465 3466 mean_iou, update_op = metrics.mean_iou( 3467 labels, predictions, num_classes, weights=weights) 3468 3469 variables.local_variables_initializer().run() 3470 for _ in range(6): 3471 sess.run(update_op) 3472 desired_output = np.mean([2.0 / 3.0, 1.0 / 2.0]) 3473 self.assertAlmostEqual(desired_output, mean_iou.eval()) 3474 3475 def testMultipleUpdatesWithMissingClass(self): 3476 # Test the case where there are no predicions and labels for 3477 # one class, and thus there is one row and one column with 3478 # zero entries in the confusion matrix. 3479 num_classes = 3 3480 with self.test_session() as sess: 3481 # Create the queue that populates the predictions. 3482 # There is no prediction for class 2. 3483 preds_queue = data_flow_ops.FIFOQueue( 3484 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3485 _enqueue_vector(sess, preds_queue, [0]) 3486 _enqueue_vector(sess, preds_queue, [1]) 3487 _enqueue_vector(sess, preds_queue, [1]) 3488 _enqueue_vector(sess, preds_queue, [1]) 3489 _enqueue_vector(sess, preds_queue, [0]) 3490 predictions = preds_queue.dequeue() 3491 3492 # Create the queue that populates the labels. 3493 # There is label for class 2. 
3494 labels_queue = data_flow_ops.FIFOQueue( 3495 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3496 _enqueue_vector(sess, labels_queue, [0]) 3497 _enqueue_vector(sess, labels_queue, [1]) 3498 _enqueue_vector(sess, labels_queue, [1]) 3499 _enqueue_vector(sess, labels_queue, [0]) 3500 _enqueue_vector(sess, labels_queue, [1]) 3501 labels = labels_queue.dequeue() 3502 3503 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3504 3505 sess.run(variables.local_variables_initializer()) 3506 for _ in range(5): 3507 sess.run(update_op) 3508 desired_output = np.mean([1.0 / 3.0, 2.0 / 4.0]) 3509 self.assertAlmostEqual(desired_output, miou.eval()) 3510 3511 def testUpdateOpEvalIsAccumulatedConfusionMatrix(self): 3512 predictions = array_ops.concat( 3513 [ 3514 constant_op.constant( 3515 0, shape=[5]), constant_op.constant( 3516 1, shape=[5]) 3517 ], 3518 0) 3519 labels = array_ops.concat( 3520 [ 3521 constant_op.constant( 3522 0, shape=[3]), constant_op.constant( 3523 1, shape=[7]) 3524 ], 3525 0) 3526 num_classes = 2 3527 with self.test_session() as sess: 3528 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3529 sess.run(variables.local_variables_initializer()) 3530 confusion_matrix = update_op.eval() 3531 self.assertAllEqual([[3, 0], [2, 5]], confusion_matrix) 3532 desired_miou = np.mean([3. / 5., 5. 
/ 7.]) 3533 self.assertAlmostEqual(desired_miou, miou.eval()) 3534 3535 def testAllCorrect(self): 3536 predictions = array_ops.zeros([40]) 3537 labels = array_ops.zeros([40]) 3538 num_classes = 1 3539 with self.test_session() as sess: 3540 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3541 sess.run(variables.local_variables_initializer()) 3542 self.assertEqual(40, update_op.eval()[0]) 3543 self.assertEqual(1.0, miou.eval()) 3544 3545 def testAllWrong(self): 3546 predictions = array_ops.zeros([40]) 3547 labels = array_ops.ones([40]) 3548 num_classes = 2 3549 with self.test_session() as sess: 3550 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3551 sess.run(variables.local_variables_initializer()) 3552 self.assertAllEqual([[0, 0], [40, 0]], update_op.eval()) 3553 self.assertEqual(0., miou.eval()) 3554 3555 def testResultsWithSomeMissing(self): 3556 predictions = array_ops.concat( 3557 [ 3558 constant_op.constant( 3559 0, shape=[5]), constant_op.constant( 3560 1, shape=[5]) 3561 ], 3562 0) 3563 labels = array_ops.concat( 3564 [ 3565 constant_op.constant( 3566 0, shape=[3]), constant_op.constant( 3567 1, shape=[7]) 3568 ], 3569 0) 3570 num_classes = 2 3571 weights = array_ops.concat( 3572 [ 3573 constant_op.constant( 3574 0, shape=[1]), constant_op.constant( 3575 1, shape=[8]), constant_op.constant( 3576 0, shape=[1]) 3577 ], 3578 0) 3579 with self.test_session() as sess: 3580 miou, update_op = metrics.mean_iou( 3581 labels, predictions, num_classes, weights=weights) 3582 sess.run(variables.local_variables_initializer()) 3583 self.assertAllEqual([[2, 0], [2, 4]], update_op.eval()) 3584 desired_miou = np.mean([2. / 4., 4. 
/ 6.]) 3585 self.assertAlmostEqual(desired_miou, miou.eval()) 3586 3587 def testMissingClassInLabels(self): 3588 labels = constant_op.constant([ 3589 [[0, 0, 1, 1, 0, 0], 3590 [1, 0, 0, 0, 0, 1]], 3591 [[1, 1, 1, 1, 1, 1], 3592 [0, 0, 0, 0, 0, 0]]]) 3593 predictions = constant_op.constant([ 3594 [[0, 0, 2, 1, 1, 0], 3595 [0, 1, 2, 2, 0, 1]], 3596 [[0, 0, 2, 1, 1, 1], 3597 [1, 1, 2, 0, 0, 0]]]) 3598 num_classes = 3 3599 with self.test_session() as sess: 3600 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3601 sess.run(variables.local_variables_initializer()) 3602 self.assertAllEqual([[7, 4, 3], [3, 5, 2], [0, 0, 0]], update_op.eval()) 3603 self.assertAlmostEqual( 3604 1 / 3 * (7 / (7 + 3 + 7) + 5 / (5 + 4 + 5) + 0 / (0 + 5 + 0)), 3605 miou.eval()) 3606 3607 def testMissingClassOverallSmall(self): 3608 labels = constant_op.constant([0]) 3609 predictions = constant_op.constant([0]) 3610 num_classes = 2 3611 with self.test_session() as sess: 3612 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3613 sess.run(variables.local_variables_initializer()) 3614 self.assertAllEqual([[1, 0], [0, 0]], update_op.eval()) 3615 self.assertAlmostEqual(1, miou.eval()) 3616 3617 def testMissingClassOverallLarge(self): 3618 labels = constant_op.constant([ 3619 [[0, 0, 1, 1, 0, 0], 3620 [1, 0, 0, 0, 0, 1]], 3621 [[1, 1, 1, 1, 1, 1], 3622 [0, 0, 0, 0, 0, 0]]]) 3623 predictions = constant_op.constant([ 3624 [[0, 0, 1, 1, 0, 0], 3625 [1, 1, 0, 0, 1, 1]], 3626 [[0, 0, 0, 1, 1, 1], 3627 [1, 1, 1, 0, 0, 0]]]) 3628 num_classes = 3 3629 with self.test_session() as sess: 3630 miou, update_op = metrics.mean_iou(labels, predictions, num_classes) 3631 sess.run(variables.local_variables_initializer()) 3632 self.assertAllEqual([[9, 5, 0], [3, 7, 0], [0, 0, 0]], update_op.eval()) 3633 self.assertAlmostEqual( 3634 1 / 2 * (9 / (9 + 3 + 5) + 7 / (7 + 5 + 3)), miou.eval()) 3635 3636 3637 class MeanPerClassAccuracyTest(test.TestCase): 3638 3639 def setUp(self): 3640 
np.random.seed(1) 3641 ops.reset_default_graph() 3642 3643 def testVars(self): 3644 metrics.mean_per_class_accuracy( 3645 predictions=array_ops.ones([10, 1]), 3646 labels=array_ops.ones([10, 1]), 3647 num_classes=2) 3648 _assert_metric_variables(self, ('mean_accuracy/count:0', 3649 'mean_accuracy/total:0')) 3650 3651 def testMetricsCollections(self): 3652 my_collection_name = '__metrics__' 3653 mean_accuracy, _ = metrics.mean_per_class_accuracy( 3654 predictions=array_ops.ones([10, 1]), 3655 labels=array_ops.ones([10, 1]), 3656 num_classes=2, 3657 metrics_collections=[my_collection_name]) 3658 self.assertListEqual( 3659 ops.get_collection(my_collection_name), [mean_accuracy]) 3660 3661 def testUpdatesCollection(self): 3662 my_collection_name = '__updates__' 3663 _, update_op = metrics.mean_per_class_accuracy( 3664 predictions=array_ops.ones([10, 1]), 3665 labels=array_ops.ones([10, 1]), 3666 num_classes=2, 3667 updates_collections=[my_collection_name]) 3668 self.assertListEqual(ops.get_collection(my_collection_name), [update_op]) 3669 3670 def testPredictionsAndLabelsOfDifferentSizeRaisesValueError(self): 3671 predictions = array_ops.ones([10, 3]) 3672 labels = array_ops.ones([10, 4]) 3673 with self.assertRaises(ValueError): 3674 metrics.mean_per_class_accuracy(labels, predictions, num_classes=2) 3675 3676 def testLabelsAndWeightsOfDifferentSizeRaisesValueError(self): 3677 predictions = array_ops.ones([10]) 3678 labels = array_ops.ones([10]) 3679 weights = array_ops.zeros([9]) 3680 with self.assertRaises(ValueError): 3681 metrics.mean_per_class_accuracy( 3682 labels, predictions, num_classes=2, weights=weights) 3683 3684 def testValueTensorIsIdempotent(self): 3685 num_classes = 3 3686 predictions = random_ops.random_uniform( 3687 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3688 labels = random_ops.random_uniform( 3689 [10], maxval=num_classes, dtype=dtypes_lib.int64, seed=1) 3690 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3691 labels, 
predictions, num_classes=num_classes) 3692 3693 with self.test_session() as sess: 3694 sess.run(variables.local_variables_initializer()) 3695 3696 # Run several updates. 3697 for _ in range(10): 3698 sess.run(update_op) 3699 3700 # Then verify idempotency. 3701 initial_mean_accuracy = mean_accuracy.eval() 3702 for _ in range(10): 3703 self.assertEqual(initial_mean_accuracy, mean_accuracy.eval()) 3704 3705 num_classes = 3 3706 with self.test_session() as sess: 3707 # Create the queue that populates the predictions. 3708 preds_queue = data_flow_ops.FIFOQueue( 3709 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3710 _enqueue_vector(sess, preds_queue, [0]) 3711 _enqueue_vector(sess, preds_queue, [1]) 3712 _enqueue_vector(sess, preds_queue, [2]) 3713 _enqueue_vector(sess, preds_queue, [1]) 3714 _enqueue_vector(sess, preds_queue, [0]) 3715 predictions = preds_queue.dequeue() 3716 3717 # Create the queue that populates the labels. 3718 labels_queue = data_flow_ops.FIFOQueue( 3719 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3720 _enqueue_vector(sess, labels_queue, [0]) 3721 _enqueue_vector(sess, labels_queue, [1]) 3722 _enqueue_vector(sess, labels_queue, [1]) 3723 _enqueue_vector(sess, labels_queue, [2]) 3724 _enqueue_vector(sess, labels_queue, [1]) 3725 labels = labels_queue.dequeue() 3726 3727 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3728 labels, predictions, num_classes) 3729 3730 sess.run(variables.local_variables_initializer()) 3731 for _ in range(5): 3732 sess.run(update_op) 3733 desired_output = np.mean([1.0, 1.0 / 3.0, 0.0]) 3734 self.assertAlmostEqual(desired_output, mean_accuracy.eval()) 3735 3736 def testMultipleUpdatesWithWeights(self): 3737 num_classes = 2 3738 with self.test_session() as sess: 3739 # Create the queue that populates the predictions. 
3740 preds_queue = data_flow_ops.FIFOQueue( 3741 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3742 _enqueue_vector(sess, preds_queue, [0]) 3743 _enqueue_vector(sess, preds_queue, [1]) 3744 _enqueue_vector(sess, preds_queue, [0]) 3745 _enqueue_vector(sess, preds_queue, [1]) 3746 _enqueue_vector(sess, preds_queue, [0]) 3747 _enqueue_vector(sess, preds_queue, [1]) 3748 predictions = preds_queue.dequeue() 3749 3750 # Create the queue that populates the labels. 3751 labels_queue = data_flow_ops.FIFOQueue( 3752 6, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3753 _enqueue_vector(sess, labels_queue, [0]) 3754 _enqueue_vector(sess, labels_queue, [1]) 3755 _enqueue_vector(sess, labels_queue, [1]) 3756 _enqueue_vector(sess, labels_queue, [0]) 3757 _enqueue_vector(sess, labels_queue, [0]) 3758 _enqueue_vector(sess, labels_queue, [1]) 3759 labels = labels_queue.dequeue() 3760 3761 # Create the queue that populates the weights. 3762 weights_queue = data_flow_ops.FIFOQueue( 3763 6, dtypes=dtypes_lib.float32, shapes=(1, 1)) 3764 _enqueue_vector(sess, weights_queue, [1.0]) 3765 _enqueue_vector(sess, weights_queue, [0.5]) 3766 _enqueue_vector(sess, weights_queue, [1.0]) 3767 _enqueue_vector(sess, weights_queue, [0.0]) 3768 _enqueue_vector(sess, weights_queue, [1.0]) 3769 _enqueue_vector(sess, weights_queue, [0.0]) 3770 weights = weights_queue.dequeue() 3771 3772 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3773 labels, predictions, num_classes, weights=weights) 3774 3775 variables.local_variables_initializer().run() 3776 for _ in range(6): 3777 sess.run(update_op) 3778 desired_output = np.mean([2.0 / 2.0, 0.5 / 1.5]) 3779 self.assertAlmostEqual(desired_output, mean_accuracy.eval()) 3780 3781 def testMultipleUpdatesWithMissingClass(self): 3782 # Test the case where there are no predicions and labels for 3783 # one class, and thus there is one row and one column with 3784 # zero entries in the confusion matrix. 
3785 num_classes = 3 3786 with self.test_session() as sess: 3787 # Create the queue that populates the predictions. 3788 # There is no prediction for class 2. 3789 preds_queue = data_flow_ops.FIFOQueue( 3790 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3791 _enqueue_vector(sess, preds_queue, [0]) 3792 _enqueue_vector(sess, preds_queue, [1]) 3793 _enqueue_vector(sess, preds_queue, [1]) 3794 _enqueue_vector(sess, preds_queue, [1]) 3795 _enqueue_vector(sess, preds_queue, [0]) 3796 predictions = preds_queue.dequeue() 3797 3798 # Create the queue that populates the labels. 3799 # There is label for class 2. 3800 labels_queue = data_flow_ops.FIFOQueue( 3801 5, dtypes=dtypes_lib.int32, shapes=(1, 1)) 3802 _enqueue_vector(sess, labels_queue, [0]) 3803 _enqueue_vector(sess, labels_queue, [1]) 3804 _enqueue_vector(sess, labels_queue, [1]) 3805 _enqueue_vector(sess, labels_queue, [0]) 3806 _enqueue_vector(sess, labels_queue, [1]) 3807 labels = labels_queue.dequeue() 3808 3809 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3810 labels, predictions, num_classes) 3811 3812 sess.run(variables.local_variables_initializer()) 3813 for _ in range(5): 3814 sess.run(update_op) 3815 desired_output = np.mean([1.0 / 2.0, 2.0 / 3.0, 0.]) 3816 self.assertAlmostEqual(desired_output, mean_accuracy.eval()) 3817 3818 def testAllCorrect(self): 3819 predictions = array_ops.zeros([40]) 3820 labels = array_ops.zeros([40]) 3821 num_classes = 1 3822 with self.test_session() as sess: 3823 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3824 labels, predictions, num_classes) 3825 sess.run(variables.local_variables_initializer()) 3826 self.assertEqual(1.0, update_op.eval()[0]) 3827 self.assertEqual(1.0, mean_accuracy.eval()) 3828 3829 def testAllWrong(self): 3830 predictions = array_ops.zeros([40]) 3831 labels = array_ops.ones([40]) 3832 num_classes = 2 3833 with self.test_session() as sess: 3834 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3835 labels, predictions, 
num_classes) 3836 sess.run(variables.local_variables_initializer()) 3837 self.assertAllEqual([0.0, 0.0], update_op.eval()) 3838 self.assertEqual(0., mean_accuracy.eval()) 3839 3840 def testResultsWithSomeMissing(self): 3841 predictions = array_ops.concat([ 3842 constant_op.constant(0, shape=[5]), constant_op.constant(1, shape=[5]) 3843 ], 0) 3844 labels = array_ops.concat([ 3845 constant_op.constant(0, shape=[3]), constant_op.constant(1, shape=[7]) 3846 ], 0) 3847 num_classes = 2 3848 weights = array_ops.concat([ 3849 constant_op.constant(0, shape=[1]), constant_op.constant(1, shape=[8]), 3850 constant_op.constant(0, shape=[1]) 3851 ], 0) 3852 with self.test_session() as sess: 3853 mean_accuracy, update_op = metrics.mean_per_class_accuracy( 3854 labels, predictions, num_classes, weights=weights) 3855 sess.run(variables.local_variables_initializer()) 3856 desired_accuracy = np.array([2. / 2., 4. / 6.], dtype=np.float32) 3857 self.assertAllEqual(desired_accuracy, update_op.eval()) 3858 desired_mean_accuracy = np.mean(desired_accuracy) 3859 self.assertAlmostEqual(desired_mean_accuracy, mean_accuracy.eval()) 3860 3861 3862 class FalseNegativesTest(test.TestCase): 3863 3864 def setUp(self): 3865 np.random.seed(1) 3866 ops.reset_default_graph() 3867 3868 def testVars(self): 3869 metrics.false_negatives( 3870 labels=(0, 1, 0, 1), 3871 predictions=(0, 0, 1, 1)) 3872 _assert_metric_variables(self, ('false_negatives/count:0',)) 3873 3874 def testUnweighted(self): 3875 labels = constant_op.constant(((0, 1, 0, 1, 0), 3876 (0, 0, 1, 1, 1), 3877 (1, 1, 1, 1, 0), 3878 (0, 0, 0, 0, 1))) 3879 predictions = constant_op.constant(((0, 0, 1, 1, 0), 3880 (1, 1, 1, 1, 1), 3881 (0, 1, 0, 1, 0), 3882 (1, 1, 1, 1, 1))) 3883 tn, tn_update_op = metrics.false_negatives( 3884 labels=labels, predictions=predictions) 3885 3886 with self.test_session() as sess: 3887 sess.run(variables.local_variables_initializer()) 3888 self.assertAllClose(0., tn.eval()) 3889 self.assertAllClose(3., 
tn_update_op.eval()) 3890 self.assertAllClose(3., tn.eval()) 3891 3892 def testWeighted(self): 3893 labels = constant_op.constant(((0, 1, 0, 1, 0), 3894 (0, 0, 1, 1, 1), 3895 (1, 1, 1, 1, 0), 3896 (0, 0, 0, 0, 1))) 3897 predictions = constant_op.constant(((0, 0, 1, 1, 0), 3898 (1, 1, 1, 1, 1), 3899 (0, 1, 0, 1, 0), 3900 (1, 1, 1, 1, 1))) 3901 weights = constant_op.constant((1., 1.5, 2., 2.5)) 3902 tn, tn_update_op = metrics.false_negatives( 3903 labels=labels, predictions=predictions, weights=weights) 3904 3905 with self.test_session() as sess: 3906 sess.run(variables.local_variables_initializer()) 3907 self.assertAllClose(0., tn.eval()) 3908 self.assertAllClose(5., tn_update_op.eval()) 3909 self.assertAllClose(5., tn.eval()) 3910 3911 3912 class FalseNegativesAtThresholdsTest(test.TestCase): 3913 3914 def setUp(self): 3915 np.random.seed(1) 3916 ops.reset_default_graph() 3917 3918 def testVars(self): 3919 metrics.false_negatives_at_thresholds( 3920 predictions=array_ops.ones((10, 1)), 3921 labels=array_ops.ones((10, 1)), 3922 thresholds=[0.15, 0.5, 0.85]) 3923 _assert_metric_variables(self, ('false_negatives/false_negatives:0',)) 3924 3925 def testUnweighted(self): 3926 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 3927 (0.2, 0.9, 0.7, 0.6), 3928 (0.1, 0.2, 0.4, 0.3))) 3929 labels = constant_op.constant(((0, 1, 1, 0), 3930 (1, 0, 0, 0), 3931 (0, 0, 0, 0))) 3932 fn, fn_update_op = metrics.false_negatives_at_thresholds( 3933 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 3934 3935 with self.test_session() as sess: 3936 sess.run(variables.local_variables_initializer()) 3937 self.assertAllEqual((0, 0, 0), fn.eval()) 3938 self.assertAllEqual((0, 2, 3), fn_update_op.eval()) 3939 self.assertAllEqual((0, 2, 3), fn.eval()) 3940 3941 def testWeighted(self): 3942 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 3943 (0.2, 0.9, 0.7, 0.6), 3944 (0.1, 0.2, 0.4, 0.3))) 3945 labels = constant_op.constant(((0, 1, 1, 0), 3946 (1, 0, 0, 
0), 3947 (0, 0, 0, 0))) 3948 fn, fn_update_op = metrics.false_negatives_at_thresholds( 3949 predictions=predictions, 3950 labels=labels, 3951 weights=((3.0,), (5.0,), (7.0,)), 3952 thresholds=[0.15, 0.5, 0.85]) 3953 3954 with self.test_session() as sess: 3955 sess.run(variables.local_variables_initializer()) 3956 self.assertAllEqual((0.0, 0.0, 0.0), fn.eval()) 3957 self.assertAllEqual((0.0, 8.0, 11.0), fn_update_op.eval()) 3958 self.assertAllEqual((0.0, 8.0, 11.0), fn.eval()) 3959 3960 3961 class FalsePositivesTest(test.TestCase): 3962 3963 def setUp(self): 3964 np.random.seed(1) 3965 ops.reset_default_graph() 3966 3967 def testVars(self): 3968 metrics.false_positives( 3969 labels=(0, 1, 0, 1), 3970 predictions=(0, 0, 1, 1)) 3971 _assert_metric_variables(self, ('false_positives/count:0',)) 3972 3973 def testUnweighted(self): 3974 labels = constant_op.constant(((0, 1, 0, 1, 0), 3975 (0, 0, 1, 1, 1), 3976 (1, 1, 1, 1, 0), 3977 (0, 0, 0, 0, 1))) 3978 predictions = constant_op.constant(((0, 0, 1, 1, 0), 3979 (1, 1, 1, 1, 1), 3980 (0, 1, 0, 1, 0), 3981 (1, 1, 1, 1, 1))) 3982 tn, tn_update_op = metrics.false_positives( 3983 labels=labels, predictions=predictions) 3984 3985 with self.test_session() as sess: 3986 sess.run(variables.local_variables_initializer()) 3987 self.assertAllClose(0., tn.eval()) 3988 self.assertAllClose(7., tn_update_op.eval()) 3989 self.assertAllClose(7., tn.eval()) 3990 3991 def testWeighted(self): 3992 labels = constant_op.constant(((0, 1, 0, 1, 0), 3993 (0, 0, 1, 1, 1), 3994 (1, 1, 1, 1, 0), 3995 (0, 0, 0, 0, 1))) 3996 predictions = constant_op.constant(((0, 0, 1, 1, 0), 3997 (1, 1, 1, 1, 1), 3998 (0, 1, 0, 1, 0), 3999 (1, 1, 1, 1, 1))) 4000 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4001 tn, tn_update_op = metrics.false_positives( 4002 labels=labels, predictions=predictions, weights=weights) 4003 4004 with self.test_session() as sess: 4005 sess.run(variables.local_variables_initializer()) 4006 self.assertAllClose(0., tn.eval()) 4007 
self.assertAllClose(14., tn_update_op.eval()) 4008 self.assertAllClose(14., tn.eval()) 4009 4010 4011 class FalsePositivesAtThresholdsTest(test.TestCase): 4012 4013 def setUp(self): 4014 np.random.seed(1) 4015 ops.reset_default_graph() 4016 4017 def testVars(self): 4018 metrics.false_positives_at_thresholds( 4019 predictions=array_ops.ones((10, 1)), 4020 labels=array_ops.ones((10, 1)), 4021 thresholds=[0.15, 0.5, 0.85]) 4022 _assert_metric_variables(self, ('false_positives/false_positives:0',)) 4023 4024 def testUnweighted(self): 4025 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4026 (0.2, 0.9, 0.7, 0.6), 4027 (0.1, 0.2, 0.4, 0.3))) 4028 labels = constant_op.constant(((0, 1, 1, 0), 4029 (1, 0, 0, 0), 4030 (0, 0, 0, 0))) 4031 fp, fp_update_op = metrics.false_positives_at_thresholds( 4032 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4033 4034 with self.test_session() as sess: 4035 sess.run(variables.local_variables_initializer()) 4036 self.assertAllEqual((0, 0, 0), fp.eval()) 4037 self.assertAllEqual((7, 4, 2), fp_update_op.eval()) 4038 self.assertAllEqual((7, 4, 2), fp.eval()) 4039 4040 def testWeighted(self): 4041 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4042 (0.2, 0.9, 0.7, 0.6), 4043 (0.1, 0.2, 0.4, 0.3))) 4044 labels = constant_op.constant(((0, 1, 1, 0), 4045 (1, 0, 0, 0), 4046 (0, 0, 0, 0))) 4047 fp, fp_update_op = metrics.false_positives_at_thresholds( 4048 predictions=predictions, 4049 labels=labels, 4050 weights=((1.0, 2.0, 3.0, 5.0), 4051 (7.0, 11.0, 13.0, 17.0), 4052 (19.0, 23.0, 29.0, 31.0)), 4053 thresholds=[0.15, 0.5, 0.85]) 4054 4055 with self.test_session() as sess: 4056 sess.run(variables.local_variables_initializer()) 4057 self.assertAllEqual((0.0, 0.0, 0.0), fp.eval()) 4058 self.assertAllEqual((125.0, 42.0, 12.0), fp_update_op.eval()) 4059 self.assertAllEqual((125.0, 42.0, 12.0), fp.eval()) 4060 4061 4062 class TrueNegativesTest(test.TestCase): 4063 4064 def setUp(self): 4065 np.random.seed(1) 
4066 ops.reset_default_graph() 4067 4068 def testVars(self): 4069 metrics.true_negatives( 4070 labels=(0, 1, 0, 1), 4071 predictions=(0, 0, 1, 1)) 4072 _assert_metric_variables(self, ('true_negatives/count:0',)) 4073 4074 def testUnweighted(self): 4075 labels = constant_op.constant(((0, 1, 0, 1, 0), 4076 (0, 0, 1, 1, 1), 4077 (1, 1, 1, 1, 0), 4078 (0, 0, 0, 0, 1))) 4079 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4080 (1, 1, 1, 1, 1), 4081 (0, 1, 0, 1, 0), 4082 (1, 1, 1, 1, 1))) 4083 tn, tn_update_op = metrics.true_negatives( 4084 labels=labels, predictions=predictions) 4085 4086 with self.test_session() as sess: 4087 sess.run(variables.local_variables_initializer()) 4088 self.assertAllClose(0., tn.eval()) 4089 self.assertAllClose(3., tn_update_op.eval()) 4090 self.assertAllClose(3., tn.eval()) 4091 4092 def testWeighted(self): 4093 labels = constant_op.constant(((0, 1, 0, 1, 0), 4094 (0, 0, 1, 1, 1), 4095 (1, 1, 1, 1, 0), 4096 (0, 0, 0, 0, 1))) 4097 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4098 (1, 1, 1, 1, 1), 4099 (0, 1, 0, 1, 0), 4100 (1, 1, 1, 1, 1))) 4101 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4102 tn, tn_update_op = metrics.true_negatives( 4103 labels=labels, predictions=predictions, weights=weights) 4104 4105 with self.test_session() as sess: 4106 sess.run(variables.local_variables_initializer()) 4107 self.assertAllClose(0., tn.eval()) 4108 self.assertAllClose(4., tn_update_op.eval()) 4109 self.assertAllClose(4., tn.eval()) 4110 4111 4112 class TrueNegativesAtThresholdsTest(test.TestCase): 4113 4114 def setUp(self): 4115 np.random.seed(1) 4116 ops.reset_default_graph() 4117 4118 def testVars(self): 4119 metrics.true_negatives_at_thresholds( 4120 predictions=array_ops.ones((10, 1)), 4121 labels=array_ops.ones((10, 1)), 4122 thresholds=[0.15, 0.5, 0.85]) 4123 _assert_metric_variables(self, ('true_negatives/true_negatives:0',)) 4124 4125 def testUnweighted(self): 4126 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 
4127 (0.2, 0.9, 0.7, 0.6), 4128 (0.1, 0.2, 0.4, 0.3))) 4129 labels = constant_op.constant(((0, 1, 1, 0), 4130 (1, 0, 0, 0), 4131 (0, 0, 0, 0))) 4132 tn, tn_update_op = metrics.true_negatives_at_thresholds( 4133 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4134 4135 with self.test_session() as sess: 4136 sess.run(variables.local_variables_initializer()) 4137 self.assertAllEqual((0, 0, 0), tn.eval()) 4138 self.assertAllEqual((2, 5, 7), tn_update_op.eval()) 4139 self.assertAllEqual((2, 5, 7), tn.eval()) 4140 4141 def testWeighted(self): 4142 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4143 (0.2, 0.9, 0.7, 0.6), 4144 (0.1, 0.2, 0.4, 0.3))) 4145 labels = constant_op.constant(((0, 1, 1, 0), 4146 (1, 0, 0, 0), 4147 (0, 0, 0, 0))) 4148 tn, tn_update_op = metrics.true_negatives_at_thresholds( 4149 predictions=predictions, 4150 labels=labels, 4151 weights=((0.0, 2.0, 3.0, 5.0),), 4152 thresholds=[0.15, 0.5, 0.85]) 4153 4154 with self.test_session() as sess: 4155 sess.run(variables.local_variables_initializer()) 4156 self.assertAllEqual((0.0, 0.0, 0.0), tn.eval()) 4157 self.assertAllEqual((5.0, 15.0, 23.0), tn_update_op.eval()) 4158 self.assertAllEqual((5.0, 15.0, 23.0), tn.eval()) 4159 4160 4161 class TruePositivesTest(test.TestCase): 4162 4163 def setUp(self): 4164 np.random.seed(1) 4165 ops.reset_default_graph() 4166 4167 def testVars(self): 4168 metrics.true_positives( 4169 labels=(0, 1, 0, 1), 4170 predictions=(0, 0, 1, 1)) 4171 _assert_metric_variables(self, ('true_positives/count:0',)) 4172 4173 def testUnweighted(self): 4174 labels = constant_op.constant(((0, 1, 0, 1, 0), 4175 (0, 0, 1, 1, 1), 4176 (1, 1, 1, 1, 0), 4177 (0, 0, 0, 0, 1))) 4178 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4179 (1, 1, 1, 1, 1), 4180 (0, 1, 0, 1, 0), 4181 (1, 1, 1, 1, 1))) 4182 tn, tn_update_op = metrics.true_positives( 4183 labels=labels, predictions=predictions) 4184 4185 with self.test_session() as sess: 4186 
sess.run(variables.local_variables_initializer()) 4187 self.assertAllClose(0., tn.eval()) 4188 self.assertAllClose(7., tn_update_op.eval()) 4189 self.assertAllClose(7., tn.eval()) 4190 4191 def testWeighted(self): 4192 labels = constant_op.constant(((0, 1, 0, 1, 0), 4193 (0, 0, 1, 1, 1), 4194 (1, 1, 1, 1, 0), 4195 (0, 0, 0, 0, 1))) 4196 predictions = constant_op.constant(((0, 0, 1, 1, 0), 4197 (1, 1, 1, 1, 1), 4198 (0, 1, 0, 1, 0), 4199 (1, 1, 1, 1, 1))) 4200 weights = constant_op.constant((1., 1.5, 2., 2.5)) 4201 tn, tn_update_op = metrics.true_positives( 4202 labels=labels, predictions=predictions, weights=weights) 4203 4204 with self.test_session() as sess: 4205 sess.run(variables.local_variables_initializer()) 4206 self.assertAllClose(0., tn.eval()) 4207 self.assertAllClose(12., tn_update_op.eval()) 4208 self.assertAllClose(12., tn.eval()) 4209 4210 4211 class TruePositivesAtThresholdsTest(test.TestCase): 4212 4213 def setUp(self): 4214 np.random.seed(1) 4215 ops.reset_default_graph() 4216 4217 def testVars(self): 4218 metrics.true_positives_at_thresholds( 4219 predictions=array_ops.ones((10, 1)), 4220 labels=array_ops.ones((10, 1)), 4221 thresholds=[0.15, 0.5, 0.85]) 4222 _assert_metric_variables(self, ('true_positives/true_positives:0',)) 4223 4224 def testUnweighted(self): 4225 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4226 (0.2, 0.9, 0.7, 0.6), 4227 (0.1, 0.2, 0.4, 0.3))) 4228 labels = constant_op.constant(((0, 1, 1, 0), 4229 (1, 0, 0, 0), 4230 (0, 0, 0, 0))) 4231 tp, tp_update_op = metrics.true_positives_at_thresholds( 4232 predictions=predictions, labels=labels, thresholds=[0.15, 0.5, 0.85]) 4233 4234 with self.test_session() as sess: 4235 sess.run(variables.local_variables_initializer()) 4236 self.assertAllEqual((0, 0, 0), tp.eval()) 4237 self.assertAllEqual((3, 1, 0), tp_update_op.eval()) 4238 self.assertAllEqual((3, 1, 0), tp.eval()) 4239 4240 def testWeighted(self): 4241 predictions = constant_op.constant(((0.9, 0.2, 0.8, 0.1), 4242 
(0.2, 0.9, 0.7, 0.6), 4243 (0.1, 0.2, 0.4, 0.3))) 4244 labels = constant_op.constant(((0, 1, 1, 0), 4245 (1, 0, 0, 0), 4246 (0, 0, 0, 0))) 4247 tp, tp_update_op = metrics.true_positives_at_thresholds( 4248 predictions=predictions, labels=labels, weights=37.0, 4249 thresholds=[0.15, 0.5, 0.85]) 4250 4251 with self.test_session() as sess: 4252 sess.run(variables.local_variables_initializer()) 4253 self.assertAllEqual((0.0, 0.0, 0.0), tp.eval()) 4254 self.assertAllEqual((111.0, 37.0, 0.0), tp_update_op.eval()) 4255 self.assertAllEqual((111.0, 37.0, 0.0), tp.eval()) 4256 4257 4258 if __name__ == '__main__': 4259 test.main() 4260