Home | History | Annotate | Download | only in estimator
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""Tests for multi_head."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 import six
     23 
     24 from tensorflow.contrib.estimator.python.estimator import head as head_lib
     25 from tensorflow.contrib.estimator.python.estimator import multi_head as multi_head_lib
     26 from tensorflow.core.framework import summary_pb2
     27 from tensorflow.python.estimator import model_fn
     28 from tensorflow.python.estimator.canned import metric_keys
     29 from tensorflow.python.estimator.canned import prediction_keys
     30 from tensorflow.python.framework import constant_op
     31 from tensorflow.python.framework import ops
     32 from tensorflow.python.ops import string_ops
     33 from tensorflow.python.platform import test
     34 from tensorflow.python.saved_model import signature_constants
     35 
     36 
# Short alias for the default serving signature key; used to look up the
# default entry in `spec.export_outputs`.
_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY
     38 
     39 
     40 def _initialize_variables(test_case, scaffold):
     41   scaffold.finalize()
     42   test_case.assertIsNone(scaffold.init_feed_dict)
     43   test_case.assertIsNone(scaffold.init_fn)
     44   scaffold.init_op.run()
     45   scaffold.ready_for_local_init_op.eval()
     46   scaffold.local_init_op.run()
     47   scaffold.ready_op.eval()
     48   test_case.assertIsNotNone(scaffold.saver)
     49 
     50 
     51 def _assert_simple_summaries(test_case, expected_summaries, summary_str,
     52                              tol=1e-6):
     53   """Assert summary the specified simple values.
     54 
     55   Args:
     56     test_case: test case.
     57     expected_summaries: Dict of expected tags and simple values.
     58     summary_str: Serialized `summary_pb2.Summary`.
     59     tol: Tolerance for relative and absolute.
     60   """
     61   summary = summary_pb2.Summary()
     62   summary.ParseFromString(summary_str)
     63   test_case.assertAllClose(expected_summaries, {
     64       v.tag: v.simple_value for v in summary.value
     65   }, rtol=tol, atol=tol)
     66 
     67 
     68 def _assert_no_hooks(test_case, spec):
     69   test_case.assertAllEqual([], spec.training_chief_hooks)
     70   test_case.assertAllEqual([], spec.training_hooks)
     71 
     72 
     73 def _sigmoid(logits):
     74   return 1 / (1 + np.exp(-logits))
     75 
     76 
     77 class MultiHeadTest(test.TestCase):
     78 
     79   def setUp(self):
     80     ops.reset_default_graph()
     81 
     82   def test_no_heads(self):
     83     with self.assertRaisesRegexp(
     84         ValueError, r'Must specify heads\. Given: \[\]'):
     85       multi_head_lib.multi_head(heads=[])
     86 
     87   def test_head_name_missing(self):
     88     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     89     head2 = head_lib.multi_label_head(n_classes=3)
     90     with self.assertRaisesRegexp(
     91         ValueError, r'All given heads must have name specified\.'):
     92       multi_head_lib.multi_head([head1, head2])
     93 
     94   def test_head_weights_wrong_size(self):
     95     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
     96     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
     97     with self.assertRaisesRegexp(
     98         ValueError,
     99         r'heads and head_weights must have the same size\. '
    100         r'Given len\(heads\): 2. Given len\(head_weights\): 1\.'):
    101       multi_head_lib.multi_head([head1, head2], head_weights=[1.])
    102 
    103   def test_name(self):
    104     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    105     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
    106     multi_head = multi_head_lib.multi_head([head1, head2])
    107     self.assertEqual('head1_head2', multi_head.name)
    108 
    109   def test_predict_two_heads_logits_dict(self):
    110     """Tests predict with logits as dict."""
    111     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    112     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
    113     multi_head = multi_head_lib.multi_head([head1, head2])
    114 
    115     logits = {
    116         'head1': np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32),
    117         'head2': np.array([[2., -2., 2.], [-3., 2., -2.]], dtype=np.float32)
    118     }
    119     expected_probabilities = {
    120         'head1': _sigmoid(logits['head1']),
    121         'head2': _sigmoid(logits['head2']),
    122     }
    123 
    124     spec = multi_head.create_estimator_spec(
    125         features={'x': np.array(((42,),), dtype=np.int32)},
    126         mode=model_fn.ModeKeys.PREDICT,
    127         logits=logits)
    128 
    129     self.assertItemsEqual(
    130         (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1',
    131          'head2', 'classification/head2', 'predict/head2'),
    132         spec.export_outputs.keys())
    133 
    134     # Assert predictions and export_outputs.
    135     with self.test_session() as sess:
    136       _initialize_variables(self, spec.scaffold)
    137       self.assertIsNone(spec.scaffold.summary_op)
    138       predictions = sess.run(spec.predictions)
    139       self.assertAllClose(
    140           logits['head1'],
    141           predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])
    142       self.assertAllClose(
    143           logits['head2'],
    144           predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])
    145       self.assertAllClose(
    146           expected_probabilities['head1'],
    147           predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])
    148       self.assertAllClose(
    149           expected_probabilities['head2'],
    150           predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])
    151 
    152       self.assertAllClose(
    153           expected_probabilities['head1'],
    154           sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
    155       self.assertAllClose(
    156           expected_probabilities['head1'],
    157           sess.run(spec.export_outputs['head1'].scores))
    158       self.assertAllClose(
    159           expected_probabilities['head2'],
    160           sess.run(spec.export_outputs['head2'].scores))
    161 
    162   def test_predict_two_heads_logits_tensor(self):
    163     """Tests predict with logits as Tensor."""
    164     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    165     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
    166     multi_head = multi_head_lib.multi_head([head1, head2])
    167 
    168     logits = np.array(
    169         [[-1., 1., 2., -2., 2.], [-1.5, 1., -3., 2., -2.]], dtype=np.float32)
    170     expected_logits1 = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
    171     expected_logits2 = np.array([[2., -2., 2.], [-3., 2., -2.]],
    172                                 dtype=np.float32)
    173     expected_probabilities = {
    174         'head1': _sigmoid(expected_logits1),
    175         'head2': _sigmoid(expected_logits2),
    176     }
    177 
    178     spec = multi_head.create_estimator_spec(
    179         features={'x': np.array(((42,),), dtype=np.int32)},
    180         mode=model_fn.ModeKeys.PREDICT,
    181         logits=logits)
    182 
    183     self.assertItemsEqual(
    184         (_DEFAULT_SERVING_KEY, 'head1', 'classification/head1', 'predict/head1',
    185          'head2', 'classification/head2', 'predict/head2'),
    186         spec.export_outputs.keys())
    187 
    188     # Assert predictions and export_outputs.
    189     with self.test_session() as sess:
    190       _initialize_variables(self, spec.scaffold)
    191       self.assertIsNone(spec.scaffold.summary_op)
    192       predictions = sess.run(spec.predictions)
    193       self.assertAllClose(
    194           expected_logits1,
    195           predictions[('head1', prediction_keys.PredictionKeys.LOGITS)])
    196       self.assertAllClose(
    197           expected_logits2,
    198           predictions[('head2', prediction_keys.PredictionKeys.LOGITS)])
    199       self.assertAllClose(
    200           expected_probabilities['head1'],
    201           predictions[('head1', prediction_keys.PredictionKeys.PROBABILITIES)])
    202       self.assertAllClose(
    203           expected_probabilities['head2'],
    204           predictions[('head2', prediction_keys.PredictionKeys.PROBABILITIES)])
    205 
    206       self.assertAllClose(
    207           expected_probabilities['head1'],
    208           sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
    209       self.assertAllClose(
    210           expected_probabilities['head1'],
    211           sess.run(spec.export_outputs['head1'].scores))
    212       self.assertAllClose(
    213           expected_probabilities['head2'],
    214           sess.run(spec.export_outputs['head2'].scores))
    215 
    216   def test_predict_two_heads_logits_tensor_multi_dim(self):
    217     """Tests predict with multi-dimensional logits of shape [2, 2, 5]."""
    218     head1 = head_lib.regression_head(label_dimension=2, name='head1')
    219     head2 = head_lib.regression_head(label_dimension=3, name='head2')
    220     multi_head = multi_head_lib.multi_head([head1, head2])
    221 
    222     logits = np.array(
    223         [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
    224          [[-1.5, 1., -3., 2., -2.], [-1.5, 1., -3., 2., -2.]]],
    225         dtype=np.float32)
    226     expected_logits1 = np.array(
    227         [[[-1., 1.], [-1., 1.]],
    228          [[-1.5, 1.], [-1.5, 1.]]],
    229         dtype=np.float32)
    230     expected_logits2 = np.array(
    231         [[[2., -2., 2.], [2., -2., 2.]],
    232          [[-3., 2., -2.], [-3., 2., -2.]]],
    233         dtype=np.float32)
    234 
    235     spec = multi_head.create_estimator_spec(
    236         features={'x': np.array(((42,),), dtype=np.int32)},
    237         mode=model_fn.ModeKeys.PREDICT,
    238         logits=logits)
    239 
    240     self.assertItemsEqual(
    241         (_DEFAULT_SERVING_KEY, 'head1', 'regression/head1', 'predict/head1',
    242          'head2', 'regression/head2', 'predict/head2'),
    243         spec.export_outputs.keys())
    244 
    245     # Assert predictions and export_outputs.
    246     with self.test_session() as sess:
    247       _initialize_variables(self, spec.scaffold)
    248       self.assertIsNone(spec.scaffold.summary_op)
    249       predictions = sess.run(spec.predictions)
    250       self.assertAllClose(
    251           expected_logits1,
    252           predictions[('head1', prediction_keys.PredictionKeys.PREDICTIONS)])
    253       self.assertAllClose(
    254           expected_logits2,
    255           predictions[('head2', prediction_keys.PredictionKeys.PREDICTIONS)])
    256 
    257       self.assertAllClose(
    258           expected_logits1,
    259           sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].value))
    260       self.assertAllClose(
    261           expected_logits1,
    262           sess.run(spec.export_outputs['head1'].value))
    263       self.assertAllClose(
    264           expected_logits2,
    265           sess.run(spec.export_outputs['head2'].value))
    266 
    267   def test_eval_two_heads_with_weights(self):
    268     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    269     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
    270     multi_head = multi_head_lib.multi_head(
    271         [head1, head2], head_weights=[1., 2.])
    272 
    273     logits = {
    274         'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
    275         'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
    276                           dtype=np.float32),
    277     }
    278     labels = {
    279         'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    280         'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    281     }
    282     # For large logits, sigmoid cross entropy loss is approximated as:
    283     # loss = labels * (logits < 0) * (-logits) +
    284     #        (1 - labels) * (logits > 0) * logits =>
    285     # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
    286     # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
    287     # Average over classes, weighted sum over batch and heads.
    288     expected_loss_head1 = 17.5
    289     expected_loss_head2 = 30.0
    290     expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
    291 
    292     spec = multi_head.create_estimator_spec(
    293         features={'x': np.array(((42,),), dtype=np.int32)},
    294         mode=model_fn.ModeKeys.EVAL,
    295         logits=logits,
    296         labels=labels)
    297 
    298     keys = metric_keys.MetricKeys
    299     expected_metrics = {
    300         keys.LOSS + '/head1': expected_loss_head1,
    301         keys.LOSS + '/head2': expected_loss_head2,
    302         # Average loss over examples.
    303         keys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
    304         keys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
    305         # auc and auc_pr cannot be reliably calculated for only 4-6 samples, but
    306         # this assert tests that the algorithm remains consistent.
    307         keys.AUC + '/head1': 0.1667,
    308         keys.AUC + '/head2': 0.3333,
    309         keys.AUC_PR + '/head1': 0.49999964,
    310         keys.AUC_PR + '/head2': 0.33333313,
    311     }
    312 
    313     # Assert spec contains expected tensors.
    314     self.assertIsNotNone(spec.loss)
    315     self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    316     self.assertIsNone(spec.train_op)
    317     self.assertIsNone(spec.export_outputs)
    318     _assert_no_hooks(self, spec)
    319 
    320     # Assert predictions, loss, and metrics.
    321     tol = 1e-3
    322     with self.test_session() as sess:
    323       _initialize_variables(self, spec.scaffold)
    324       self.assertIsNone(spec.scaffold.summary_op)
    325       value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
    326       update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
    327       loss, metrics = sess.run((spec.loss, update_ops))
    328       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
    329       # Check results of both update (in `metrics`) and value ops.
    330       self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
    331       self.assertAllClose(
    332           expected_metrics, {k: value_ops[k].eval() for k in value_ops},
    333           rtol=tol,
    334           atol=tol)
    335 
    336   def test_train_create_loss_one_head(self):
    337     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    338     multi_head = multi_head_lib.multi_head([head1])
    339 
    340     logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}
    341     labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
    342     loss = multi_head.create_loss(
    343         features={'x': np.array(((42,),), dtype=np.int32)},
    344         mode=model_fn.ModeKeys.TRAIN,
    345         logits=logits,
    346         labels=labels)[0]
    347     tol = 1e-3
    348     with self.test_session():
    349       # Unreduced loss of the head is [[(10 + 10) / 2], (15 + 0) / 2]
    350       # (averaged over classes, sum-reduced over examples).
    351       self.assertAllClose(17.5, loss.eval(), rtol=tol, atol=tol)
    352 
    353   def test_train_create_loss_two_heads_with_weights(self):
    354     # Use different example weighting for each head weighting.
    355     weights1 = np.array([[1.], [2.]], dtype=np.float32)
    356     weights2 = np.array([[2.], [3.]])
    357     head1 = head_lib.multi_label_head(n_classes=2, name='head1',
    358                                       weight_column='weights1')
    359     head2 = head_lib.multi_label_head(n_classes=3, name='head2',
    360                                       weight_column='weights2')
    361     multi_head = multi_head_lib.multi_head(
    362         [head1, head2], head_weights=[1., 2.])
    363 
    364     logits = {
    365         'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
    366         'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
    367                           dtype=np.float32),
    368     }
    369     labels = {
    370         'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    371         'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    372     }
    373     training_loss, unreduced_losses, weights, _ = multi_head.create_loss(
    374         features={
    375             'x': np.array(((42,),), dtype=np.int32),
    376             'weights1': weights1,
    377             'weights2': weights2
    378         },
    379         mode=model_fn.ModeKeys.TRAIN,
    380         logits=logits,
    381         labels=labels)
    382     tol = 1e-3
    383     with self.test_session():
    384       # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
    385       # = [10, 7.5]
    386       # training_loss = 1 * 10 + 2 * 7.5 = 25
    387       # head-weighted unreduced_loss = 1 * [10, 7.5]
    388       self.assertAllClose(
    389           [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
    390       # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
    391       # = [20, 10]
    392       # training_loss = 2 * 20 + 3 * 10 = 70
    393       # head-weighted unreduced_loss = 2 * [20, 10]
    394       self.assertAllClose(
    395           [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
    396       # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
    397       self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
    398       # head-weighted example weights
    399       self.assertAllClose(
    400           [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
    401       self.assertAllClose(
    402           [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol)
    403 
    404   def test_train_create_loss_logits_tensor(self):
    405     """Tests create_loss with logits Tensor."""
    406     weights1 = np.array([[1.], [2.]], dtype=np.float32)
    407     weights2 = np.array([[2.], [3.]])
    408     head1 = head_lib.multi_label_head(n_classes=2, name='head1',
    409                                       weight_column='weights1')
    410     head2 = head_lib.multi_label_head(n_classes=3, name='head2',
    411                                       weight_column='weights2')
    412     multi_head = multi_head_lib.multi_head(
    413         [head1, head2], head_weights=[1., 2.])
    414 
    415     logits = np.array([[-10., 10., 20., -20., 20.],
    416                        [-15., 10., -30., 20., -20.]], dtype=np.float32)
    417     labels = {
    418         'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    419         'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    420     }
    421     training_loss, unreduced_losses, weights, _ = multi_head.create_loss(
    422         features={
    423             'x': np.array(((42,),), dtype=np.int32),
    424             'weights1': weights1,
    425             'weights2': weights2
    426         },
    427         mode=model_fn.ModeKeys.TRAIN,
    428         logits=logits,
    429         labels=labels)
    430     tol = 1e-3
    431     with self.test_session():
    432       # loss of the first head is [[(10 + 10) / 2], [(15 + 0) / 2]]
    433       # = [10, 7.5]
    434       # training_loss = 1 * 10 + 2 * 7.5 = 25
    435       # head-weighted unreduced_loss = 1 * [10, 7.5]
    436       self.assertAllClose(
    437           [[10.], [7.5]], unreduced_losses['head1'].eval(), rtol=tol, atol=tol)
    438       # loss of the second head is [[(20 + 20 + 20) / 3], [(30 + 0 + 0) / 3]]
    439       # = [20, 10]
    440       # training_loss = 2 * 20 + 3 * 10 = 70
    441       # head-weighted unreduced_loss = 2 * [20, 10]
    442       self.assertAllClose(
    443           [[40.], [20.]], unreduced_losses['head2'].eval(), rtol=tol, atol=tol)
    444       # head-weighted training_loss = 1 * 25 + 2 * 70 = 165
    445       self.assertAllClose(165, training_loss.eval(), rtol=tol, atol=tol)
    446       # head-weighted example weights
    447       self.assertAllClose(
    448           [[1.], [2.]], weights['head1'].eval(), rtol=tol, atol=tol)
    449       self.assertAllClose(
    450           [[4.], [6.]], weights['head2'].eval(), rtol=tol, atol=tol)
    451 
    452   def test_train_create_loss_logits_tensor_multi_dim(self):
    453     """Tests create_loss with multi-dimensional logits of shape [2, 2, 5]."""
    454     head1 = head_lib.regression_head(label_dimension=2, name='head1')
    455     head2 = head_lib.regression_head(label_dimension=3, name='head2')
    456     multi_head = multi_head_lib.multi_head([head1, head2])
    457 
    458     logits = np.array(
    459         [[[-1., 1., 2., -2., 2.], [-1., 1., 2., -2., 2.]],
    460          [[-1.5, 1.5, -2., 2., -2.], [-1.5, 1.5, -2., 2., -2.]]],
    461         dtype=np.float32)
    462     labels = {
    463         'head1': np.array([[[1., 0.], [1., 0.]],
    464                            [[1.5, 1.5], [1.5, 1.5]]], dtype=np.float32),
    465         'head2': np.array([[[0., 1., 0.], [0., 1., 0.]],
    466                            [[2., 2., 0.], [2., 2., 0.]]], dtype=np.float32),
    467     }
    468     # Loss for the first head:
    469     # loss1 = (1+1)^2 + (0-1)^2 + (1+1)^2 + (0-1)^2 +
    470     #         (1.5+1.5)^2 + (1.5-1.5)^2 + (1.5+1.5)^2 + (1.5-1.5)^2
    471     #       = 28
    472     # Loss for the second head:
    473     # loss2 = (0-2)^2 + (1+2)^2 + (0-2)^2 + (0-2)^2 + (1+2)^2 + (0-2)^2 +
    474     #         (2+2)^2 + (2-2)^2 + (0+2)^2 + (2+2)^2 + (2-2)^2 + (0+2)^2
    475     #       = 74
    476     expected_training_loss = 28. + 74.
    477 
    478     training_loss = multi_head.create_loss(
    479         features={},
    480         mode=model_fn.ModeKeys.TRAIN,
    481         logits=logits,
    482         labels=labels)[0]
    483     tol = 1e-3
    484     with self.test_session():
    485       self.assertAllClose(
    486           expected_training_loss, training_loss.eval(), rtol=tol, atol=tol)
    487 
    488   def test_train_one_head(self):
    489     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    490     multi_head = multi_head_lib.multi_head([head1])
    491 
    492     logits = {'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)}
    493     labels = {'head1': np.array([[1, 0], [1, 1]], dtype=np.int64)}
    494     # For large logits, sigmoid cross entropy loss is approximated as:
    495     # loss = labels * (logits < 0) * (-logits) +
    496     #        (1 - labels) * (logits > 0) * logits =>
    497     # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    498     # Average over classes, sum over weights.
    499     expected_loss = 17.5
    500     expected_train_result = 'my_train_op'
    501     def _train_op_fn(loss):
    502       return string_ops.string_join(
    503           [constant_op.constant(expected_train_result),
    504            string_ops.as_string(loss, precision=3)])
    505 
    506     spec = multi_head.create_estimator_spec(
    507         features={'x': np.array(((42,),), dtype=np.int32)},
    508         mode=model_fn.ModeKeys.TRAIN,
    509         logits=logits,
    510         labels=labels,
    511         train_op_fn=_train_op_fn)
    512 
    513     self.assertIsNotNone(spec.loss)
    514     self.assertEqual({}, spec.eval_metric_ops)
    515     self.assertIsNotNone(spec.train_op)
    516     self.assertIsNone(spec.export_outputs)
    517     _assert_no_hooks(self, spec)
    518 
    519     # Assert predictions, loss, train_op, and summaries.
    520     tol = 1e-3
    521     with self.test_session() as sess:
    522       _initialize_variables(self, spec.scaffold)
    523       self.assertIsNotNone(spec.scaffold.summary_op)
    524       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
    525                                                   spec.scaffold.summary_op))
    526       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
    527       self.assertEqual(
    528           six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
    529           train_result)
    530       _assert_simple_summaries(self, {
    531           metric_keys.MetricKeys.LOSS: expected_loss,
    532           metric_keys.MetricKeys.LOSS + '/head1': expected_loss,
    533           # Average loss over examples.
    534           metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss / 2,
    535       }, summary_str, tol)
    536 
    537   def test_train_two_heads_with_weights(self):
    538     head1 = head_lib.multi_label_head(n_classes=2, name='head1')
    539     head2 = head_lib.multi_label_head(n_classes=3, name='head2')
    540     multi_head = multi_head_lib.multi_head(
    541         [head1, head2], head_weights=[1., 2.])
    542 
    543     logits = {
    544         'head1': np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
    545         'head2': np.array([[20., -20., 20.], [-30., 20., -20.]],
    546                           dtype=np.float32),
    547     }
    548     labels = {
    549         'head1': np.array([[1, 0], [1, 1]], dtype=np.int64),
    550         'head2': np.array([[0, 1, 0], [1, 1, 0]], dtype=np.int64),
    551     }
    552     # For large logits, sigmoid cross entropy loss is approximated as:
    553     # loss = labels * (logits < 0) * (-logits) +
    554     #        (1 - labels) * (logits > 0) * logits =>
    555     # head1: expected_unweighted_loss = [[10., 10.], [15., 0.]]
    556     # head2: expected_unweighted_loss = [[20., 20., 20.], [30., 0., 0]]
    557     # Average over classes, weighted sum over batch and heads.
    558     expected_loss_head1 = 17.5
    559     expected_loss_head2 = 30.0
    560     expected_loss = 1. * expected_loss_head1 + 2. * expected_loss_head2
    561     expected_train_result = 'my_train_op'
    562     def _train_op_fn(loss):
    563       return string_ops.string_join(
    564           [constant_op.constant(expected_train_result),
    565            string_ops.as_string(loss, precision=3)])
    566 
    567     spec = multi_head.create_estimator_spec(
    568         features={'x': np.array(((42,),), dtype=np.int32)},
    569         mode=model_fn.ModeKeys.TRAIN,
    570         logits=logits,
    571         labels=labels,
    572         train_op_fn=_train_op_fn)
    573 
    574     self.assertIsNotNone(spec.loss)
    575     self.assertEqual({}, spec.eval_metric_ops)
    576     self.assertIsNotNone(spec.train_op)
    577     self.assertIsNone(spec.export_outputs)
    578     _assert_no_hooks(self, spec)
    579 
    580     # Assert predictions, loss, train_op, and summaries.
    581     tol = 1e-3
    582     with self.test_session() as sess:
    583       _initialize_variables(self, spec.scaffold)
    584       self.assertIsNotNone(spec.scaffold.summary_op)
    585       loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
    586                                                   spec.scaffold.summary_op))
    587       self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
    588       self.assertEqual(
    589           six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
    590           train_result)
    591       _assert_simple_summaries(self, {
    592           metric_keys.MetricKeys.LOSS: expected_loss,
    593           metric_keys.MetricKeys.LOSS + '/head1': expected_loss_head1,
    594           metric_keys.MetricKeys.LOSS + '/head2': expected_loss_head2,
    595           # Average loss over examples.
    596           metric_keys.MetricKeys.LOSS_MEAN + '/head1': expected_loss_head1 / 2,
    597           metric_keys.MetricKeys.LOSS_MEAN + '/head2': expected_loss_head2 / 2,
    598       }, summary_str, tol)
    599 
    600 
if __name__ == '__main__':
  # Run the tests in this module with TensorFlow's test runner.
  test.main()
    603