1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for head.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 import six 23 24 from tensorflow.contrib.estimator.python.estimator import head as head_lib 25 from tensorflow.contrib.estimator.python.estimator import multi_head as multi_head_lib 26 from tensorflow.core.framework import summary_pb2 27 from tensorflow.python.estimator import model_fn 28 from tensorflow.python.estimator.canned import metric_keys 29 from tensorflow.python.estimator.canned import prediction_keys 30 from tensorflow.python.framework import constant_op 31 from tensorflow.python.framework import ops 32 from tensorflow.python.ops import string_ops 33 from tensorflow.python.platform import test 34 from tensorflow.python.saved_model import signature_constants 35 36 37 _DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY 38 39 40 def _initialize_variables(test_case, scaffold): 41 scaffold.finalize() 42 test_case.assertIsNone(scaffold.init_feed_dict) 43 test_case.assertIsNone(scaffold.init_fn) 44 scaffold.init_op.run() 45 scaffold.ready_for_local_init_op.eval() 46 scaffold.local_init_op.run() 47 scaffold.ready_op.eval() 48 test_case.assertIsNotNone(scaffold.saver) 49 50 51 def _assert_simple_summaries(test_case, expected_summaries, summary_str, 52 tol=1e-6): 53 """Assert summary the specified simple values. 54 55 Args: 56 test_case: test case. 57 expected_summaries: Dict of expected tags and simple values. 58 summary_str: Serialized `summary_pb2.Summary`. 59 tol: Tolerance for relative and absolute. 60 """ 61 summary = summary_pb2.Summary() 62 summary.ParseFromString(summary_str) 63 test_case.assertAllClose(expected_summaries, { 64 v.tag: v.simple_value for v in summary.value 65 }, rtol=tol, atol=tol) 66 67 68 def _assert_no_hooks(test_case, spec): 69 test_case.assertAllEqual([], spec.training_chief_hooks) 70 test_case.assertAllEqual([], spec.training_hooks) 71 72 73 def _sigmoid(logits): 74 return 1 / (1 + np.exp(-logits)) 75 76 77 class MultiHeadTest(test.TestCase): 78 79 def setUp(self): 80 ops.reset_default_graph() 81 82 def test_no_heads(self): 83 with self.assertRaisesRegexp( 84 ValueError, r'Must specify heads\. Given: \[\]'): 85 multi_head_lib.multi_head(heads=[]) 86 87 def test_head_name_missing(self): 88 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 89 head2 = head_lib.multi_label_head(n_classes=3) 90 with self.assertRaisesRegexp( 91 ValueError, r'All given heads must have name specified\.'): 92 multi_head_lib.multi_head([head1, head2]) 93 94 def test_head_weights_wrong_size(self): 95 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 96 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 97 with self.assertRaisesRegexp( 98 ValueError, 99 r'heads and head_weights must have the same size\. ' 100 r'Given len\(heads\): 2. Given len\(head_weights\): 1\.'): 101 multi_head_lib.multi_head([head1, head2], head_weights=[1.]) 102 103 def test_name(self): 104 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 105 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 106 multi_head = multi_head_lib.multi_head([head1, head2]) 107 self.assertEqual('head1_head2', multi_head.name) 108 109 def test_predict_two_heads_logits_dict(self): 110 """Tests predict with logits as dict.""" 111 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 112 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 113 multi_head = multi_head_lib.multi_head([head1, head2]) 114 115 logits = { 116 'head1': np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32), 117 'head2': np.array([[2., -2., 2.], [-3., 2., -2.]], dtype=np.float32) 118 } 119 expected_probabilities = { 120 'head1': _sigmoid(logits['head1']), 121 'head2': _sigmoid(logits['head2']), 122 } 123 124 spec = multi_head.create_estimator_spec( 125 features={'x': np.array(((42,),), dtype=np.int32)}, 126 mode=model_fn.ModeKeys.PREDICT, 127 logits=logits) 128 129 self.assertItemsEqual( 130 (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', 131 'head2', 'classification/head2', 'predict/head2'), 132 spec.export_outputs.keys()) 133 134 # Assert predictions and export_outputs. 135 with self.test_session() as sess: 136 _initialize_variables(self, spec.scaffold) 137 self.assertIsNone(spec.scaffold.summary_op) 138 predictions = sess.run(spec.predictions) 139 self.assertAllClose( 140 logits['head1'], 141 predictions[('head1', prediction_keys.PredictionKeys.LOGITS)]) 142 self.assertAllClose( 143 logits['head2'], 144 predictions[('head2', prediction_keys.PredictionKeys.LOGITS)]) 145 self.assertAllClose( 146 expected_probabilities['head1'], 147 predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)]) 148 self.assertAllClose( 149 expected_probabilities['head2'], 150 predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)]) 151 152 self.assertAllClose( 153 expected_probabilities['head1'], 154 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) 155 self.assertAllClose( 156 expected_probabilities['head1'], 157 sess.run(spec.export_outputs['head1'].scores)) 158 self.assertAllClose( 159 expected_probabilities['head2'], 160 sess.run(spec.export_outputs['head2'].scores)) 161 162 def test_predict_two_heads_logits_tensor(self): 163 """Tests predict with logits as Tensor.""" 164 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 165 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 166 multi_head = multi_head_lib.multi_head([head1, head2]) 167 168 logits = np.array( 169 [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32) 170 expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32) 171 expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]], 172 dtype=np.float32) 173 expected_probabilities = { 174 'head1': _sigmoid(expected_logits1), 175 'head2': _sigmoid(expected_logits2), 176 } 177 178 spec = multi_head.create_estimator_spec( 179 features={'x': np.array(((42,),), dtype=np.int32)}, 180 mode=model_fn.ModeKeys.PREDICT, 181 logits=logits) 182 183 self.assertItemsEqual( 184 (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1', 185 'head2', 'classification/head2', 'predict/head2'), 186 spec.export_outputs.keys()) 187 188 # Assert predictions and export_outputs. 189 with self.test_session() as sess: 190 _initialize_variables(self, spec.scaffold) 191 self.assertIsNone(spec.scaffold.summary_op) 192 predictions = sess.run(spec.predictions) 193 self.assertAllClose( 194 expected_logits1, 195 predictions[('head1', prediction_keys.PredictionKeys.LOGITS)]) 196 self.assertAllClose( 197 expected_logits2, 198 predictions[('head2', prediction_keys.PredictionKeys.LOGITS)]) 199 self.assertAllClose( 200 expected_probabilities['head1'], 201 predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)]) 202 self.assertAllClose( 203 expected_probabilities['head2'], 204 predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)]) 205 206 self.assertAllClose( 207 expected_probabilities['head1'], 208 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores)) 209 self.assertAllClose( 210 expected_probabilities['head1'], 211 sess.run(spec.export_outputs['head1'].scores)) 212 self.assertAllClose( 213 expected_probabilities['head2'], 214 sess.run(spec.export_outputs['head2'].scores)) 215 216 def test_predict_two_heads_logits_tensor_multi_dim(self): 217 """Tests predict with multi-dimensional logits of shape [2, 2, 5].""" 218 head1 = head_lib.regression_head(label_dimension=2, name='head1') 219 head2 = head_lib.regression_head(label_dimension=3, name='head2') 220 multi_head = multi_head_lib.multi_head([head1, head2]) 221 222 logits = np.array( 223 [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], 224 [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]], 225 dtype=np.float32) 226 expected_logits1 = np.array( 227 [[[-1., 1.], [-1., 1.]], 228 [[-1.5, 1.], [-1.5, 1.]]], 229 dtype=np.float32) 230 expected_logits2 = np.array( 231 [[[2., -2., 2.], [2., -2., 2.]], 232 [[-3., 2., -2.], [-3., 2., -2.]]], 233 dtype=np.float32) 234 235 spec = multi_head.create_estimator_spec( 236 features={'x': np.array(((42,),), dtype=np.int32)}, 237 mode=model_fn.ModeKeys.PREDICT, 238 logits=logits) 239 240 self.assertItemsEqual( 241 (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1', 242 'head2', 'regression/head2', 'predict/head2'), 243 spec.export_outputs.keys()) 244 245 # Assert predictions and export_outputs. 246 with self.test_session() as sess: 247 _initialize_variables(self, spec.scaffold) 248 self.assertIsNone(spec.scaffold.summary_op) 249 predictions = sess.run(spec.predictions) 250 self.assertAllClose( 251 expected_logits1, 252 predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)]) 253 self.assertAllClose( 254 expected_logits2, 255 predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)]) 256 257 self.assertAllClose( 258 expected_logits1, 259 sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value)) 260 self.assertAllClose( 261 expected_logits1, 262 sess.run(spec.export_outputs['head1'].value)) 263 self.assertAllClose( 264 expected_logits2, 265 sess.run(spec.export_outputs['head2'].value)) 266 267 def test_eval_two_heads_with_weights(self): 268 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 269 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 270 multi_head = multi_head_lib.multi_head( 271 [head1, head2], head_weights=[1., 2.]) 272 273 logits = { 274 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), 275 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], 276 dtype=np.float32), 277 } 278 labels = { 279 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 280 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), 281 } 282 # For large logits, sigmoid cross entropy loss is approximated as: 283 # loss = labels * (logits < 0) * (-logits) + 284 # (1 - labels) * (logits > 0) * logits => 285 # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] 286 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] 287 # Average over classes, weighted sum over batch and heads. 288 expected_loss_head1 = 17.5 289 expected_loss_head2 = 30.0 290 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 291 292 spec = multi_head.create_estimator_spec( 293 features={'x': np.array(((42,),), dtype=np.int32)}, 294 mode=model_fn.ModeKeys.EVAL, 295 logits=logits, 296 labels=labels) 297 298 keys = metric_keys.MetricKeys 299 expected_metrics = { 300 keys.LOSS + '/head1': expected_loss_head1, 301 keys.LOSS + '/head2': expected_loss_head2, 302 # Average loss over examples. 303 keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, 304 keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, 305 # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but 306 # this assert tests that the algorithm remains consistent. 307 keys.AUC + '/head1': 0.1667, 308 keys.AUC + '/head2': 0.3333, 309 keys.AUC_PR + '/head1': 0.49999964, 310 keys.AUC_PR + '/head2': 0.33333313, 311 } 312 313 # Assert spec contains expected tensors. 314 self.assertIsNotNone(spec.loss) 315 self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys()) 316 self.assertIsNone(spec.train_op) 317 self.assertIsNone(spec.export_outputs) 318 _assert_no_hooks(self, spec) 319 320 # Assert predictions, loss, and metrics. 321 tol = 1e-3 322 with self.test_session() as sess: 323 _initialize_variables(self, spec.scaffold) 324 self.assertIsNone(spec.scaffold.summary_op) 325 value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops} 326 update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops} 327 loss, metrics = sess.run((spec.loss, update_ops)) 328 self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) 329 # Check results of both update (in `metrics`) and value ops. 330 self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol) 331 self.assertAllClose( 332 expected_metrics, {k: value_ops[k].eval() for k in value_ops}, 333 rtol=tol, 334 atol=tol) 335 336 def test_train_create_loss_one_head(self): 337 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 338 multi_head = multi_head_lib.multi_head([head1]) 339 340 logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} 341 labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} 342 loss = multi_head.create_loss( 343 features={'x': np.array(((42,),), dtype=np.int32)}, 344 mode=model_fn.ModeKeys.TRAIN, 345 logits=logits, 346 labels=labels)[0] 347 tol = 1e-3 348 with self.test_session(): 349 # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2] 350 # (averaged over classes, sum-reduced over examples). 351 self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol) 352 353 def test_train_create_loss_two_heads_with_weights(self): 354 # Use different example weighting for each head weighting. 355 weights1 = np.array([[1.], [2.]], dtype=np.float32) 356 weights2 = np.array([[2.], [3.]]) 357 head1 = head_lib.multi_label_head(n_classes=2, name='head1', 358 weight_column='weights1') 359 head2 = head_lib.multi_label_head(n_classes=3, name='head2', 360 weight_column='weights2') 361 multi_head = multi_head_lib.multi_head( 362 [head1, head2], head_weights=[1., 2.]) 363 364 logits = { 365 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), 366 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], 367 dtype=np.float32), 368 } 369 labels = { 370 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 371 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), 372 } 373 training_loss, unreduced_losses, weights, _ = multi_head.create_loss( 374 features={ 375 'x': np.array(((42,),), dtype=np.int32), 376 'weights1': weights1, 377 'weights2': weights2 378 }, 379 mode=model_fn.ModeKeys.TRAIN, 380 logits=logits, 381 labels=labels) 382 tol = 1e-3 383 with self.test_session(): 384 # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] 385 # = [10, 7.5] 386 # training_loss = 1 * 10 + 2 * 7.5 = 25 387 # head-weighted unreduced_loss = 1 * [10, 7.5] 388 self.assertAllClose( 389 [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) 390 # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] 391 # = [20, 10] 392 # training_loss = 2 * 20 + 3 * 10 = 70 393 # head-weighted unreduced_loss = 2 * [20, 10] 394 self.assertAllClose( 395 [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) 396 # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 397 self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) 398 # head-weighted example weights 399 self.assertAllClose( 400 [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) 401 self.assertAllClose( 402 [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol) 403 404 def test_train_create_loss_logits_tensor(self): 405 """Tests create_loss with logits Tensor.""" 406 weights1 = np.array([[1.], [2.]], dtype=np.float32) 407 weights2 = np.array([[2.], [3.]]) 408 head1 = head_lib.multi_label_head(n_classes=2, name='head1', 409 weight_column='weights1') 410 head2 = head_lib.multi_label_head(n_classes=3, name='head2', 411 weight_column='weights2') 412 multi_head = multi_head_lib.multi_head( 413 [head1, head2], head_weights=[1., 2.]) 414 415 logits = np.array([[-10., 10., 20., -20., 20.], 416 [-15., 10., -30., 20., -20.]], dtype=np.float32) 417 labels = { 418 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 419 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), 420 } 421 training_loss, unreduced_losses, weights, _ = multi_head.create_loss( 422 features={ 423 'x': np.array(((42,),), dtype=np.int32), 424 'weights1': weights1, 425 'weights2': weights2 426 }, 427 mode=model_fn.ModeKeys.TRAIN, 428 logits=logits, 429 labels=labels) 430 tol = 1e-3 431 with self.test_session(): 432 # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]] 433 # = [10, 7.5] 434 # training_loss = 1 * 10 + 2 * 7.5 = 25 435 # head-weighted unreduced_loss = 1 * [10, 7.5] 436 self.assertAllClose( 437 [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol) 438 # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]] 439 # = [20, 10] 440 # training_loss = 2 * 20 + 3 * 10 = 70 441 # head-weighted unreduced_loss = 2 * [20, 10] 442 self.assertAllClose( 443 [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol) 444 # head-weighted training_loss = 1 * 25 + 2 * 70 = 165 445 self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol) 446 # head-weighted example weights 447 self.assertAllClose( 448 [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol) 449 self.assertAllClose( 450 [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol) 451 452 def test_train_create_loss_logits_tensor_multi_dim(self): 453 """Tests create_loss with multi-dimensional logits of shape [2, 2, 5].""" 454 head1 = head_lib.regression_head(label_dimension=2, name='head1') 455 head2 = head_lib.regression_head(label_dimension=3, name='head2') 456 multi_head = multi_head_lib.multi_head([head1, head2]) 457 458 logits = np.array( 459 [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]], 460 [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]], 461 dtype=np.float32) 462 labels = { 463 'head1': np.array([[[1., 0.], [1., 0.]], 464 [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32), 465 'head2': np.array([[[0., 1., 0.], [0., 1., 0.]], 466 [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32), 467 } 468 # Loss for the first head: 469 # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 + 470 # (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2 471 # = 28 472 # Loss for the second head: 473 # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 + 474 # (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2 475 # = 74 476 expected_training_loss = 28. + 74. 477 478 training_loss = multi_head.create_loss( 479 features={}, 480 mode=model_fn.ModeKeys.TRAIN, 481 logits=logits, 482 labels=labels)[0] 483 tol = 1e-3 484 with self.test_session(): 485 self.assertAllClose( 486 expected_training_loss, training_loss.eval(), rtol=tol, atol=tol) 487 488 def test_train_one_head(self): 489 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 490 multi_head = multi_head_lib.multi_head([head1]) 491 492 logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)} 493 labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)} 494 # For large logits, sigmoid cross entropy loss is approximated as: 495 # loss = labels * (logits < 0) * (-logits) + 496 # (1 - labels) * (logits > 0) * logits => 497 # expected_unweighted_loss = [[10., 10.], [15., 0.]] 498 # Average over classes, sum over weights. 499 expected_loss = 17.5 500 expected_train_result = 'my_train_op' 501 def _train_op_fn(loss): 502 return string_ops.string_join( 503 [constant_op.constant(expected_train_result), 504 string_ops.as_string(loss, precision=3)]) 505 506 spec = multi_head.create_estimator_spec( 507 features={'x': np.array(((42,),), dtype=np.int32)}, 508 mode=model_fn.ModeKeys.TRAIN, 509 logits=logits, 510 labels=labels, 511 train_op_fn=_train_op_fn) 512 513 self.assertIsNotNone(spec.loss) 514 self.assertEqual({}, spec.eval_metric_ops) 515 self.assertIsNotNone(spec.train_op) 516 self.assertIsNone(spec.export_outputs) 517 _assert_no_hooks(self, spec) 518 519 # Assert predictions, loss, train_op, and summaries. 520 tol = 1e-3 521 with self.test_session() as sess: 522 _initialize_variables(self, spec.scaffold) 523 self.assertIsNotNone(spec.scaffold.summary_op) 524 loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, 525 spec.scaffold.summary_op)) 526 self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) 527 self.assertEqual( 528 six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), 529 train_result) 530 _assert_simple_summaries(self, { 531 metric_keys.MetricKeys.LOSS: expected_loss, 532 metric_keys.MetricKeys.LOSS + '/head1': expected_loss, 533 # Average loss over examples. 534 metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2, 535 }, summary_str, tol) 536 537 def test_train_two_heads_with_weights(self): 538 head1 = head_lib.multi_label_head(n_classes=2, name='head1') 539 head2 = head_lib.multi_label_head(n_classes=3, name='head2') 540 multi_head = multi_head_lib.multi_head( 541 [head1, head2], head_weights=[1., 2.]) 542 543 logits = { 544 'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32), 545 'head2': np.array([[20., -20., 20.], [-30., 20., -20.]], 546 dtype=np.float32), 547 } 548 labels = { 549 'head1': np.array([[1, 0], [1, 1]], dtype=np.int64), 550 'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64), 551 } 552 # For large logits, sigmoid cross entropy loss is approximated as: 553 # loss = labels * (logits < 0) * (-logits) + 554 # (1 - labels) * (logits > 0) * logits => 555 # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]] 556 # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]] 557 # Average over classes, weighted sum over batch and heads. 558 expected_loss_head1 = 17.5 559 expected_loss_head2 = 30.0 560 expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2 561 expected_train_result = 'my_train_op' 562 def _train_op_fn(loss): 563 return string_ops.string_join( 564 [constant_op.constant(expected_train_result), 565 string_ops.as_string(loss, precision=3)]) 566 567 spec = multi_head.create_estimator_spec( 568 features={'x': np.array(((42,),), dtype=np.int32)}, 569 mode=model_fn.ModeKeys.TRAIN, 570 logits=logits, 571 labels=labels, 572 train_op_fn=_train_op_fn) 573 574 self.assertIsNotNone(spec.loss) 575 self.assertEqual({}, spec.eval_metric_ops) 576 self.assertIsNotNone(spec.train_op) 577 self.assertIsNone(spec.export_outputs) 578 _assert_no_hooks(self, spec) 579 580 # Assert predictions, loss, train_op, and summaries. 581 tol = 1e-3 582 with self.test_session() as sess: 583 _initialize_variables(self, spec.scaffold) 584 self.assertIsNotNone(spec.scaffold.summary_op) 585 loss, train_result, summary_str = sess.run((spec.loss, spec.train_op, 586 spec.scaffold.summary_op)) 587 self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol) 588 self.assertEqual( 589 six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)), 590 train_result) 591 _assert_simple_summaries(self, { 592 metric_keys.MetricKeys.LOSS: expected_loss, 593 metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1, 594 metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2, 595 # Average loss over examples. 596 metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2, 597 metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2, 598 }, summary_str, tol) 599 600 601 if __name__ == '__main__': 602 test.main() 603