1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for head.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 import six 23 24 from tensorflow.contrib.estimator.python.estimator import head as head_lib 25 from tensorflow.core.framework import summary_pb2 26 from tensorflow.python.estimator import model_fn 27 from tensorflow.python.estimator.canned import metric_keys 28 from tensorflow.python.estimator.canned import prediction_keys 29 from tensorflow.python.framework import constant_op 30 from tensorflow.python.framework import dtypes 31 from tensorflow.python.framework import errors 32 from tensorflow.python.framework import ops 33 from tensorflow.python.framework import sparse_tensor 34 from tensorflow.python.ops import array_ops 35 from tensorflow.python.ops import control_flow_ops 36 from tensorflow.python.ops import math_ops 37 from tensorflow.python.ops import string_ops 38 from tensorflow.python.ops.losses import losses 39 from tensorflow.python.platform import test 40 from tensorflow.python.saved_model import signature_constants 41 from tensorflow.python.training import monitored_session 42 43 44 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 45 46 47 def 
_initialize_variables(test_case, scaffold): 48 scaffold.finalize() 49 test_case.assertIsNone(scaffold.init_feed_dict) 50 test_case.assertIsNone(scaffold.init_fn) 51 scaffold.init_op.run() 52 scaffold.ready_for_local_init_op.eval() 53 scaffold.local_init_op.run() 54 scaffold.ready_op.eval() 55 test_case.assertIsNotNone(scaffold.saver) 56 57 58 def _assert_simple_summaries(test_case, expected_summaries, summary_str, 59 tol=1e-6): 60 """Assert summary the specified simple values. 61 62 Args: 63 test_case: test case. 64 expected_summaries: Dict of expected tags and simple values. 65 summary_str: Serialized `summary_pb2.Summary`. 66 tol: Tolerance for relative and absolute. 67 """ 68 summary = summary_pb2.Summary() 69 summary.ParseFromString(summary_str) 70 test_case.assertAllClose(expected_summaries, { 71 v.tag: v.simple_value for v in summary.value 72 }, rtol=tol, atol=tol) 73 74 75 def _assert_no_hooks(test_case, spec): 76 test_case.assertAllEqual([], spec.training_chief_hooks) 77 test_case.assertAllEqual([], spec.training_hooks) 78 79 80 def _sigmoid(logits): 81 return 1 / (1 + np.exp(-logits)) 82 83 84 def _sigmoid_cross_entropy(labels, logits): 85 """Returns sigmoid cross entropy averaged over classes.""" 86 sigmoid_logits = _sigmoid(logits) 87 unreduced_result = ( 88 -labels * np.log(sigmoid_logits) 89 -(1 - labels) * np.log(1 - sigmoid_logits)) 90 # Mean over classes 91 return np.mean(unreduced_result, axis=-1, keepdims=True) 92 93 94 class MultiLabelHead(test.TestCase): 95 96 def setUp(self): 97 ops.reset_default_graph() 98 99 def test_n_classes_is_none(self): 100 with self.assertRaisesRegexp( 101 ValueError, 102 r'n_classes must be > 1 for multi-class classification\. Given: None'): 103 head_lib.multi_label_head(n_classes=None) 104 105 def test_n_classes_is_1(self): 106 with self.assertRaisesRegexp( 107 ValueError, 108 r'n_classes must be > 1 for multi-class classification\. 
Given: 1'): 109 head_lib.multi_label_head(n_classes=1) 110 111 def test_threshold_too_small(self): 112 with self.assertRaisesRegexp( 113 ValueError, 114 r'thresholds must be in \(0, 1\) range\. Given: 0\.0'): 115 head_lib.multi_label_head(n_classes=2, thresholds=[0., 0.5]) 116 117 def test_threshold_too_large(self): 118 with self.assertRaisesRegexp( 119 ValueError, 120 r'thresholds must be in \(0, 1\) range\. Given: 1\.0'): 121 head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0]) 122 123 def test_label_vocabulary_dict(self): 124 with self.assertRaisesRegexp( 125 ValueError, 126 r'label_vocabulary must be a list or tuple\. ' 127 r'Given type: <(type|class) \'dict\'>'): 128 head_lib.multi_label_head(n_classes=2, label_vocabulary={'foo': 'bar'}) 129 130 def test_label_vocabulary_wrong_size(self): 131 with self.assertRaisesRegexp( 132 ValueError, 133 r'Length of label_vocabulary must be n_classes \(3\). Given: 2'): 134 head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar']) 135 136 def test_invalid_loss_reduction(self): 137 with self.assertRaisesRegexp( 138 ValueError, r'Invalid loss_reduction: invalid_loss_reduction'): 139 head_lib.multi_label_head( 140 n_classes=3, loss_reduction='invalid_loss_reduction') 141 with self.assertRaisesRegexp( 142 ValueError, r'Invalid loss_reduction: none'): 143 head_lib.multi_label_head( 144 n_classes=3, loss_reduction=losses.Reduction.NONE) 145 146 def test_loss_fn_arg_labels_missing(self): 147 def _loss_fn(logits): 148 del logits # Unused 149 with self.assertRaisesRegexp( 150 ValueError, 151 r'loss_fn must contain argument: labels\. ' 152 r'Given arguments: \(\'logits\',\)'): 153 head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) 154 155 def test_loss_fn_arg_logits_missing(self): 156 def _loss_fn(labels): 157 del labels # unused 158 with self.assertRaisesRegexp( 159 ValueError, 160 r'loss_fn must contain argument: logits\. 
' 161 r'Given arguments: \(\'labels\',\)'): 162 head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) 163 164 def test_loss_fn_arg_features_ok(self): 165 def _loss_fn(labels, logits, features): 166 del labels, logits, features # Unused 167 head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) 168 169 def test_loss_fn_arg_invalid(self): 170 def _loss_fn(labels, logits, name=None): 171 del labels, logits, name # Unused 172 with self.assertRaisesRegexp( 173 ValueError, 174 r'loss_fn has unexpected args: \[\'name\'\]'): 175 head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn) 176 177 def test_name(self): 178 head = head_lib.multi_label_head(n_classes=4, name='foo') 179 self.assertEqual('foo', head.name) 180 181 def test_predict(self): 182 n_classes = 4 183 head = head_lib.multi_label_head(n_classes) 184 self.assertEqual(n_classes, head.logits_dimension) 185 186 logits = np.array( 187 [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) 188 expected_probabilities = _sigmoid(logits) 189 expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2 190 191 spec = head.create_estimator_spec( 192 features={'x': np.array(((42,),), dtype=np.int32)}, 193 mode=model_fn.ModeKeys.PREDICT, 194 logits=logits) 195 196 self.assertItemsEqual( 197 (_DEFAULT_SERVING_KEY, 'predict', 'classification'), 198 spec.export_outputs.keys()) 199 200 # Assert predictions and export_outputs. 
201 with self.test_session() as sess: 202 _initialize_variables(self, spec.scaffold) 203 self.assertIsNone(spec.scaffold.summary_op) 204 predictions = sess.run(spec.predictions) 205 self.assertAllClose(logits, 206 predictions[prediction_keys.PredictionKeys.LOGITS]) 207 self.assertAllClose( 208 expected_probabilities, 209 predictions[prediction_keys.PredictionKeys.PROBABILITIES]) 210 211 self.assertAllClose( 212 expected_probabilities, 213 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) 214 self.assertAllEqual( 215 expected_export_classes, 216 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) 217 218 def test_predict_with_label_vocabulary(self): 219 n_classes = 4 220 head = head_lib.multi_label_head( 221 n_classes, label_vocabulary=['foo', 'bar', 'foobar', 'barfoo']) 222 223 logits = np.array( 224 [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) 225 expected_export_classes = [[b'foo', b'bar', b'foobar', b'barfoo']] * 2 226 227 spec = head.create_estimator_spec( 228 features={'x': np.array(((42,),), dtype=np.int32)}, 229 mode=model_fn.ModeKeys.PREDICT, 230 logits=logits) 231 232 with self.test_session() as sess: 233 _initialize_variables(self, spec.scaffold) 234 self.assertAllEqual( 235 expected_export_classes, 236 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes)) 237 238 def test_weight_should_not_impact_prediction(self): 239 n_classes = 4 240 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') 241 self.assertEqual(n_classes, head.logits_dimension) 242 243 logits = np.array( 244 [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32) 245 expected_probabilities = _sigmoid(logits) 246 247 weights_2x1 = [[1.], [2.]] 248 spec = head.create_estimator_spec( 249 features={ 250 'x': np.array(((42,),), dtype=np.int32), 251 'example_weights': weights_2x1, 252 }, 253 mode=model_fn.ModeKeys.PREDICT, 254 logits=logits) 255 256 # Assert predictions and export_outputs. 
257 with self.test_session() as sess: 258 _initialize_variables(self, spec.scaffold) 259 self.assertIsNone(spec.scaffold.summary_op) 260 predictions = sess.run(spec.predictions) 261 self.assertAllClose(logits, 262 predictions[prediction_keys.PredictionKeys.LOGITS]) 263 self.assertAllClose( 264 expected_probabilities, 265 predictions[prediction_keys.PredictionKeys.PROBABILITIES]) 266 267 def test_eval_create_loss(self): 268 """Tests head.create_loss for eval mode.""" 269 n_classes = 2 270 head = head_lib.multi_label_head(n_classes) 271 272 logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) 273 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 274 # loss = labels * -log(sigmoid(logits)) + 275 # (1 - labels) * -log(1 - sigmoid(logits)) 276 expected_training_loss = np.sum( 277 _sigmoid_cross_entropy(labels=labels, logits=logits)) 278 actual_training_loss = head.create_loss( 279 features={'x': np.array(((42,),), dtype=np.int32)}, 280 mode=model_fn.ModeKeys.EVAL, 281 logits=logits, 282 labels=labels)[0] 283 with self.test_session(): 284 _initialize_variables(self, monitored_session.Scaffold()) 285 self.assertAllClose(expected_training_loss, 286 actual_training_loss.eval()) 287 288 def test_eval_create_loss_large_logits(self): 289 """Tests head.create_loss for eval mode and large logits.""" 290 n_classes = 2 291 head = head_lib.multi_label_head(n_classes) 292 293 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 294 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 295 # loss = labels * -log(sigmoid(logits)) + 296 # (1 - labels) * -log(1 - sigmoid(logits)) 297 # For large logits, this is approximated as: 298 # loss = labels * (logits < 0) * (-logits) + 299 # (1 - labels) * (logits > 0) * logits 300 expected_training_loss = np.sum( 301 np.array([[(10. + 10.) / 2.], [(15. + 0.) 
/ 2.]], dtype=np.float32)) 302 actual_training_loss = head.create_loss( 303 features={'x': np.array(((42,),), dtype=np.int32)}, 304 mode=model_fn.ModeKeys.EVAL, 305 logits=logits, 306 labels=labels)[0] 307 with self.test_session(): 308 _initialize_variables(self, monitored_session.Scaffold()) 309 self.assertAllClose( 310 expected_training_loss, actual_training_loss.eval(), atol=1e-4) 311 312 def test_eval_create_loss_labels_wrong_shape(self): 313 """Tests head.create_loss for eval mode when labels has the wrong shape.""" 314 n_classes = 2 315 head = head_lib.multi_label_head(n_classes) 316 317 logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) 318 labels_placeholder = array_ops.placeholder(dtype=dtypes.int64) 319 actual_training_loss = head.create_loss( 320 features={'x': np.array(((42,),), dtype=np.int32)}, 321 mode=model_fn.ModeKeys.EVAL, 322 logits=logits, 323 labels=labels_placeholder)[0] 324 with self.test_session(): 325 _initialize_variables(self, monitored_session.Scaffold()) 326 with self.assertRaisesRegexp( 327 errors.InvalidArgumentError, 328 r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'): 329 actual_training_loss.eval({ 330 labels_placeholder: np.array([[1], [1]], dtype=np.int64) 331 }) 332 with self.assertRaisesRegexp( 333 errors.InvalidArgumentError, 334 r'labels shape must be \[D0, D1, ... 
DN, 2\]\..*' 335 r'\[Received shape: \] \[2\]'): 336 actual_training_loss.eval({ 337 labels_placeholder: np.array([1, 1], dtype=np.int64) 338 }) 339 340 def test_eval_create_loss_loss_fn(self): 341 """Tests head.create_loss for eval mode and custom loss_fn.""" 342 loss = np.array([[1.], [2.]], dtype=np.float32) 343 logits_input = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 344 labels_input = np.array([[1, 0], [1, 1]], dtype=np.int64) 345 def _loss_fn(labels, logits): 346 check_labels = control_flow_ops.Assert( 347 math_ops.reduce_all(math_ops.equal(labels, labels_input)), 348 data=[labels]) 349 check_logits = control_flow_ops.Assert( 350 math_ops.reduce_all(math_ops.equal(logits, logits_input)), 351 data=[logits]) 352 with ops.control_dependencies([check_labels, check_logits]): 353 return constant_op.constant(loss) 354 head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) 355 356 actual_training_loss = head.create_loss( 357 features={'x': np.array(((42,),), dtype=np.int32)}, 358 mode=model_fn.ModeKeys.EVAL, 359 logits=logits_input, 360 labels=labels_input)[0] 361 with self.test_session(): 362 _initialize_variables(self, monitored_session.Scaffold()) 363 self.assertAllClose(np.sum(loss), actual_training_loss.eval()) 364 365 def test_eval_create_loss_loss_fn_wrong_shape(self): 366 """Tests custom loss_fn that returns Tensor of unexpected shape.""" 367 loss = np.array([1., 2.], dtype=np.float32) 368 def _loss_fn(labels, logits): 369 del labels, logits # Unused 370 return constant_op.constant(loss) 371 head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn) 372 373 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 374 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 375 actual_training_loss = head.create_loss( 376 features={'x': np.array(((42,),), dtype=np.int32)}, 377 mode=model_fn.ModeKeys.EVAL, 378 logits=logits, 379 labels=labels)[0] 380 with self.test_session(): 381 _initialize_variables(self, 
monitored_session.Scaffold()) 382 with self.assertRaisesRegexp( 383 errors.InvalidArgumentError, 384 r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] ' 385 r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'): 386 actual_training_loss.eval() 387 388 def test_eval_labels_none(self): 389 """Tests that error is raised when labels is None.""" 390 head = head_lib.multi_label_head(n_classes=2) 391 392 with self.assertRaisesRegexp( 393 ValueError, r'You must provide a labels Tensor\. Given: None\.'): 394 head.create_estimator_spec( 395 features={'x': np.array(((42,),), dtype=np.int32)}, 396 mode=model_fn.ModeKeys.EVAL, 397 logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), 398 labels=None) 399 400 def _test_eval( 401 self, head, logits, labels, expected_loss, expected_metrics, 402 features=None, regularization_losses=None): 403 spec = head.create_estimator_spec( 404 features=features or {}, 405 mode=model_fn.ModeKeys.EVAL, 406 logits=logits, 407 labels=labels, 408 regularization_losses=regularization_losses) 409 410 # Assert spec contains expected tensors. 411 self.assertIsNotNone(spec.loss) 412 self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) 413 self.assertIsNone(spec.train_op) 414 self.assertIsNone(spec.export_outputs) 415 _assert_no_hooks(self, spec) 416 417 # Assert predictions, loss, and metrics. 418 tol = 1e-3 419 with self.test_session() as sess: 420 _initialize_variables(self, spec.scaffold) 421 self.assertIsNone(spec.scaffold.summary_op) 422 value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} 423 update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} 424 loss, metrics = sess.run((spec.loss, update_ops)) 425 self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) 426 # Check results of both update (in `metrics`) and value ops. 
427 self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol) 428 self.assertAllClose( 429 expected_metrics, {k: value_ops[k].eval() for k in value_ops}, 430 rtol=tol, 431 atol=tol) 432 433 def test_eval(self): 434 n_classes = 2 435 head = head_lib.multi_label_head(n_classes) 436 logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) 437 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 438 # loss = labels * -log(sigmoid(logits)) + 439 # (1 - labels) * -log(1 - sigmoid(logits)) 440 # Sum over examples. 441 expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) 442 keys = metric_keys.MetricKeys 443 expected_metrics = { 444 # Average loss over examples. 445 keys.LOSS_MEAN: expected_loss / 2, 446 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 447 # this assert tests that the algorithm remains consistent. 448 keys.AUC: 0.3333, 449 keys.AUC_PR: 0.5972, 450 } 451 self._test_eval( 452 head=head, 453 logits=logits, 454 labels=labels, 455 expected_loss=expected_loss, 456 expected_metrics=expected_metrics) 457 458 def test_eval_sparse_labels(self): 459 n_classes = 2 460 head = head_lib.multi_label_head(n_classes) 461 logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) 462 # Equivalent to multi_hot = [[1, 0], [1, 1]] 463 labels = sparse_tensor.SparseTensor( 464 values=[0, 0, 1], 465 indices=[[0, 0], [1, 0], [1, 1]], 466 dense_shape=[2, 2]) 467 labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) 468 # loss = labels * -log(sigmoid(logits)) + 469 # (1 - labels) * -log(1 - sigmoid(logits)) 470 # Sum over examples. 471 expected_loss = ( 472 np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) 473 ) 474 keys = metric_keys.MetricKeys 475 expected_metrics = { 476 # Average loss over examples. 477 keys.LOSS_MEAN: expected_loss / 2, 478 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 479 # this assert tests that the algorithm remains consistent. 
480 keys.AUC: 0.3333, 481 keys.AUC_PR: 0.5972, 482 } 483 self._test_eval( 484 head=head, 485 logits=logits, 486 labels=labels, 487 expected_loss=expected_loss, 488 expected_metrics=expected_metrics) 489 490 def test_eval_with_regularization_losses(self): 491 n_classes = 2 492 head = head_lib.multi_label_head( 493 n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE) 494 logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) 495 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 496 regularization_losses = [1.5, 0.5] 497 expected_regularization_loss = 2. 498 # unregularized_loss = sum( 499 # labels * -log(sigmoid(logits)) + 500 # (1 - labels) * -log(1 - sigmoid(logits))) / batch_size 501 expected_unregularized_loss = np.sum( 502 _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2. 503 expected_regularized_loss = ( 504 expected_unregularized_loss + expected_regularization_loss) 505 keys = metric_keys.MetricKeys 506 expected_metrics = { 507 keys.LOSS_MEAN: expected_unregularized_loss, 508 keys.LOSS_REGULARIZATION: expected_regularization_loss, 509 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 510 # this assert tests that the algorithm remains consistent. 
511 keys.AUC: 0.3333, 512 keys.AUC_PR: 0.5972, 513 } 514 self._test_eval( 515 head=head, 516 logits=logits, 517 labels=labels, 518 expected_loss=expected_regularized_loss, 519 expected_metrics=expected_metrics, 520 regularization_losses=regularization_losses) 521 522 def test_eval_with_label_vocabulary(self): 523 n_classes = 2 524 head = head_lib.multi_label_head( 525 n_classes, label_vocabulary=['class0', 'class1']) 526 logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) 527 # Equivalent to multi_hot = [[1, 0], [1, 1]] 528 labels = sparse_tensor.SparseTensor( 529 values=['class0', 'class0', 'class1'], 530 indices=[[0, 0], [1, 0], [1, 1]], 531 dense_shape=[2, 2]) 532 labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64) 533 # loss = labels * -log(sigmoid(logits)) + 534 # (1 - labels) * -log(1 - sigmoid(logits)) 535 # Sum over examples. 536 expected_loss = ( 537 np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits)) 538 ) 539 keys = metric_keys.MetricKeys 540 expected_metrics = { 541 # Average loss over examples. 542 keys.LOSS_MEAN: expected_loss / 2, 543 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 544 # this assert tests that the algorithm remains consistent. 545 keys.AUC: 0.3333, 546 keys.AUC_PR: 0.5972, 547 } 548 self._test_eval( 549 head=head, 550 logits=logits, 551 labels=labels, 552 expected_loss=expected_loss, 553 expected_metrics=expected_metrics) 554 555 def test_eval_with_thresholds(self): 556 n_classes = 2 557 thresholds = [0.25, 0.5, 0.75] 558 head = head_lib.multi_label_head(n_classes, thresholds=thresholds) 559 560 logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32) 561 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 562 # loss = labels * -log(sigmoid(logits)) + 563 # (1 - labels) * -log(1 - sigmoid(logits)) 564 # Sum over examples. 
565 expected_loss = ( 566 np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits)) 567 ) 568 569 keys = metric_keys.MetricKeys 570 expected_metrics = { 571 # Average loss over examples. 572 keys.LOSS_MEAN: expected_loss / 2, 573 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 574 # this assert tests that the algorithm remains consistent. 575 keys.AUC: 0.3333, 576 keys.AUC_PR: 0.5972, 577 keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4., 578 keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3., 579 keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3., 580 keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4., 581 keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2., 582 keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3., 583 keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4., 584 keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1., 585 keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3., 586 } 587 588 self._test_eval( 589 head=head, 590 logits=logits, 591 labels=labels, 592 expected_loss=expected_loss, 593 expected_metrics=expected_metrics) 594 595 def test_eval_with_weights(self): 596 n_classes = 2 597 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') 598 599 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 600 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 601 # For large logits, sigmoid cross entropy loss is approximated as: 602 # loss = labels * (logits < 0) * (-logits) + 603 # (1 - labels) * (logits > 0) * logits => 604 # expected_unweighted_loss = [[10., 10.], [15., 0.]] 605 # Average over classes, weighted sum over examples. 606 expected_loss = 25. 
607 608 spec = head.create_estimator_spec( 609 features={ 610 'x': np.array([[41], [42]], dtype=np.int32), 611 'example_weights': np.array([[1.], [2.]], dtype=np.float32), 612 }, 613 mode=model_fn.ModeKeys.EVAL, 614 logits=logits, 615 labels=labels) 616 617 keys = metric_keys.MetricKeys 618 expected_metrics = { 619 # Average loss over weighted examples. 620 keys.LOSS_MEAN: expected_loss / 3, 621 # auc and auc_pr cannot be reliably calculated for only 4 samples, but 622 # this assert tests that the algorithm remains consistent. 623 keys.AUC: 0.2000, 624 keys.AUC_PR: 0.5833, 625 } 626 627 # Assert spec contains expected tensors. 628 self.assertIsNotNone(spec.loss) 629 self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) 630 self.assertIsNone(spec.train_op) 631 self.assertIsNone(spec.export_outputs) 632 _assert_no_hooks(self, spec) 633 634 # Assert predictions, loss, and metrics. 635 tol = 1e-3 636 with self.test_session() as sess: 637 _initialize_variables(self, spec.scaffold) 638 self.assertIsNone(spec.scaffold.summary_op) 639 value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} 640 update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} 641 loss, metrics = sess.run((spec.loss, update_ops)) 642 self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) 643 # Check results of both update (in `metrics`) and value ops. 
644 self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol) 645 self.assertAllClose( 646 expected_metrics, {k: value_ops[k].eval() for k in value_ops}, 647 rtol=tol, 648 atol=tol) 649 650 def test_train_create_loss_large_logits(self): 651 """Tests head.create_loss for train mode and large logits.""" 652 n_classes = 2 653 head = head_lib.multi_label_head(n_classes, weight_column='example_weights') 654 655 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 656 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 657 weights = np.array([[1.], [2.]], dtype=np.float32) 658 # loss = labels * -log(sigmoid(logits)) + 659 # (1 - labels) * -log(1 - sigmoid(logits)) 660 # For large logits, this is approximated as: 661 # loss = labels * (logits < 0) * (-logits) + 662 # (1 - labels) * (logits > 0) * logits 663 expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] 664 expected_weights = [[1.], [2.]] 665 expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2. 
666 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( 667 features={ 668 'x': np.array(((42,),), dtype=np.int32), 669 'example_weights': weights 670 }, 671 mode=model_fn.ModeKeys.TRAIN, 672 logits=logits, 673 labels=labels) 674 with self.test_session(): 675 _initialize_variables(self, monitored_session.Scaffold()) 676 self.assertAllClose( 677 expected_training_loss, training_loss.eval(), atol=1e-4) 678 self.assertAllClose( 679 expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) 680 self.assertAllClose(expected_weights, actual_weights.eval()) 681 682 def test_train_create_loss_loss_reduction(self): 683 """Tests head.create_loss with loss_reduction.""" 684 n_classes = 2 685 head = head_lib.multi_label_head( 686 n_classes, weight_column='example_weights', 687 loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS) 688 689 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 690 labels = np.array([[1, 0], [1, 1]], dtype=np.int64) 691 weights = np.array([[1.], [2.]], dtype=np.float32) 692 # loss = labels * -log(sigmoid(logits)) + 693 # (1 - labels) * -log(1 - sigmoid(logits)) 694 # For large logits, this is approximated as: 695 # loss = labels * (logits < 0) * (-logits) + 696 # (1 - labels) * (logits > 0) * logits 697 expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]] 698 expected_weights = [[1.], [2.]] 699 expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2. 
700 training_loss, unreduced_loss, actual_weights, _ = head.create_loss( 701 features={ 702 'x': np.array(((42,),), dtype=np.int32), 703 'example_weights': weights 704 }, 705 mode=model_fn.ModeKeys.TRAIN, 706 logits=logits, 707 labels=labels) 708 with self.test_session(): 709 _initialize_variables(self, monitored_session.Scaffold()) 710 self.assertAllClose( 711 expected_training_loss, training_loss.eval(), atol=1e-4) 712 self.assertAllClose( 713 expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4) 714 self.assertAllClose(expected_weights, actual_weights.eval()) 715 716 def test_train_labels_none(self): 717 """Tests that error is raised when labels is None.""" 718 head = head_lib.multi_label_head(n_classes=2) 719 def _no_op_train_fn(loss): 720 del loss 721 return control_flow_ops.no_op() 722 723 with self.assertRaisesRegexp( 724 ValueError, r'You must provide a labels Tensor\. Given: None\.'): 725 head.create_estimator_spec( 726 features={'x': np.array(((42,),), dtype=np.int32)}, 727 mode=model_fn.ModeKeys.TRAIN, 728 logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), 729 labels=None, 730 train_op_fn=_no_op_train_fn) 731 732 def test_train_invalid_indicator_labels(self): 733 head = head_lib.multi_label_head(n_classes=2) 734 logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32) 735 # The value 2 is outside the allowed range. 
    labels = np.array([[2, 0], [1, 1]], dtype=np.int64)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      # The range check runs inside the graph, so the error surfaces only when
      # the loss tensor is evaluated.
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels must be an integer indicator Tensor with values in '
          r'\[0, 1\]'):
        sess.run(spec.loss)

  def test_train_invalid_sparse_labels(self):
    """Tests that TRAIN-mode loss rejects out-of-range sparse label values."""
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # The value 2 is outside the allowed range.
    labels = sparse_tensor.SparseTensor(
        values=[2, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      # As above, the validation error is raised at evaluation time.
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels must be an integer SparseTensor with values in \[0, 2\)'):
        sess.run(spec.loss)

  def _test_train(self, head, logits, labels, expected_loss):
    """Shared TRAIN-mode check for loss, train_op result, and summaries.

    Args:
      head: The head to create an EstimatorSpec from.
      logits: Logits np.array passed to the head.
      labels: Labels (dense np.array or SparseTensor) passed to the head.
      expected_loss: Scalar loss the spec is expected to produce.
    """
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      # Encode the loss value into the returned train_op string so the test
      # can verify that train_op_fn received the head's loss tensor.
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    # TRAIN mode populates loss/train_op but no eval metrics, export outputs,
    # or hooks.
    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over examples.
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
      }, summary_str, tol)

  def test_train(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #     (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over weights.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_sparse_labels(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=[0, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #     (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over weights.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_with_label_vocabulary(self):
    head = head_lib.multi_label_head(
        n_classes=2, label_vocabulary=['class0', 'class1'])
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=['class0', 'class0', 'class1'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #     (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over weights.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_with_regularization_losses(self):
    head = head_lib.multi_label_head(
        n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    regularization_losses = [1.5, 0.5]
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #     (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes and over batch and add regularization loss.
    expected_loss = 35. / 4. + 2.
    expected_summaries = {
        metric_keys.MetricKeys.LOSS: expected_loss,
        metric_keys.MetricKeys.LOSS_REGULARIZATION: 2.,
    }
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      # Encode the loss value into the returned train_op string so the test
      # can verify that train_op_fn received the head's loss tensor.
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        regularization_losses=regularization_losses)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, expected_summaries, summary_str, tol)

  def test_train_with_weights(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='example_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #     (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, weighted sum over examples.
    expected_loss = 25.
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={
            'x': np.array([[41], [42]], dtype=np.int32),
            'example_weights': np.array([[1.], [2.]], dtype=np.float32),
        },
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over weighted examples (weight sum is 1 + 2 = 3).
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
      }, summary_str, tol)

  def test_multi_dim_weighted_train_create_loss(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # unreduced_loss =
    #     [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #     = [[20/3, 10/3], [4, 8]]
    expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]]
    # weights are reshaped to [2, 2, 1] to match logits.
    expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_training_loss = 39.6667
    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    atol = 1.e-3
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), atol=atol)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), atol=atol)
      self.assertAllClose(expected_weights, actual_weights.eval())

  def test_multi_dim_weighted_train(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #     = [[20/3, 10/3], [4, 8]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_loss = 39.6667
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    atol = 1.e-3
    with self.test_session() as sess:
      _initialize_variables(self, monitored_session.Scaffold())
      loss, train_result = sess.run((spec.loss, spec.train_op))
      self.assertAllClose(expected_loss, loss, atol=atol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)

  def test_multi_dim_weights_wrong_inner_dim(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 1]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1.], [2.]], dtype=np.float32)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      # The weights/logits shape mismatch surfaces when the loss is evaluated.
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
        spec.loss.eval()

  def test_multi_dim_weights_wrong_outer_dim(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2, 3]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]],
                        [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32)
    # NOTE(review): feeding the weights through a placeholder presumably keeps
    # the shape unknown at graph-build time so the mismatch is caught at run
    # time — the error below is raised when the loss is evaluated.
    weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights_placeholder},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'):
        spec.loss.eval({weights_placeholder: weights})

  def test_multi_dim_weighted_eval(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #     = [[20/3, 10/3], [4, 8]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_loss = 39.6667
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: expected_loss / np.sum(weights),
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.4977,
        keys.AUC_PR: 0.4037,
    }
    self._test_eval(
        head=head,
        features={'weights': weights},
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)


if __name__ == '__main__':
  test.main()