# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for head."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import numpy as np
import six

from tensorflow.contrib.estimator.python.estimator import head as head_lib
from tensorflow.core.framework import summary_pb2
from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import errors
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import string_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.platform import test
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.training import monitored_session


_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def _initialize_variables(test_case, scaffold):
  scaffold.finalize()
  test_case.assertIsNone(scaffold.init_feed_dict)
  test_case.assertIsNone(scaffold.init_fn)
  scaffold.init_op.run()
  scaffold.ready_for_local_init_op.eval()
  scaffold.local_init_op.run()
  scaffold.ready_op.eval()
  test_case.assertIsNotNone(scaffold.saver)


def _assert_simple_summaries(test_case, expected_summaries, summary_str,
                             tol=1e-6):
  """Asserts that the summary contains the specified simple values.

  Args:
    test_case: A test case instance.
    expected_summaries: Dict of expected tags and simple values.
    summary_str: Serialized `summary_pb2.Summary`.
    tol: Tolerance for relative and absolute comparison.
  """
  summary = summary_pb2.Summary()
  summary.ParseFromString(summary_str)
  test_case.assertAllClose(expected_summaries, {
      v.tag: v.simple_value for v in summary.value
  }, rtol=tol, atol=tol)


def _assert_no_hooks(test_case, spec):
  test_case.assertAllEqual([], spec.training_chief_hooks)
  test_case.assertAllEqual([], spec.training_hooks)


def _sigmoid(logits):
  return 1 / (1 + np.exp(-logits))


def _sigmoid_cross_entropy(labels, logits):
  """Returns sigmoid cross entropy averaged over classes."""
  sigmoid_logits = _sigmoid(logits)
  unreduced_result = (
      -labels * np.log(sigmoid_logits)
      -(1 - labels) * np.log(1 - sigmoid_logits))
  # Mean over classes
  return np.mean(unreduced_result, axis=-1, keepdims=True)

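# A worked example of the helper above (illustrative, values rounded): for
# logits = [[-1., 1.]] and labels = [[1, 0]],
#   sigmoid(logits)   = [[0.269, 0.731]]
#   per-class losses  = [[-log(0.269), -log(1 - 0.731)]] = [[1.313, 1.313]]
#   mean over classes = [[1.313]]
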

class MultiLabelHead(test.TestCase):

  def setUp(self):
    ops.reset_default_graph()

  def test_n_classes_is_none(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'n_classes must be > 1 for multi-class classification\. Given: None'):
      head_lib.multi_label_head(n_classes=None)

  def test_n_classes_is_1(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'n_classes must be > 1 for multi-class classification\. Given: 1'):
      head_lib.multi_label_head(n_classes=1)

  def test_threshold_too_small(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'thresholds must be in \(0, 1\) range\. Given: 0\.0'):
      head_lib.multi_label_head(n_classes=2, thresholds=[0., 0.5])

  def test_threshold_too_large(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'thresholds must be in \(0, 1\) range\. Given: 1\.0'):
      head_lib.multi_label_head(n_classes=2, thresholds=[0.5, 1.0])

  def test_label_vocabulary_dict(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'label_vocabulary must be a list or tuple\. '
        r'Given type: <(type|class) \'dict\'>'):
      head_lib.multi_label_head(n_classes=2, label_vocabulary={'foo': 'bar'})

  def test_label_vocabulary_wrong_size(self):
    with self.assertRaisesRegexp(
        ValueError,
        r'Length of label_vocabulary must be n_classes \(3\). Given: 2'):
      head_lib.multi_label_head(n_classes=3, label_vocabulary=['foo', 'bar'])

  def test_invalid_loss_reduction(self):
    with self.assertRaisesRegexp(
        ValueError, r'Invalid loss_reduction: invalid_loss_reduction'):
      head_lib.multi_label_head(
          n_classes=3, loss_reduction='invalid_loss_reduction')
    with self.assertRaisesRegexp(
        ValueError, r'Invalid loss_reduction: none'):
      head_lib.multi_label_head(
          n_classes=3, loss_reduction=losses.Reduction.NONE)

  def test_loss_fn_arg_labels_missing(self):
    def _loss_fn(logits):
      del logits  # Unused
    with self.assertRaisesRegexp(
        ValueError,
        r'loss_fn must contain argument: labels\. '
        r'Given arguments: \(\'logits\',\)'):
      head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn)

  def test_loss_fn_arg_logits_missing(self):
    def _loss_fn(labels):
      del labels  # unused
    with self.assertRaisesRegexp(
        ValueError,
        r'loss_fn must contain argument: logits\. '
        r'Given arguments: \(\'labels\',\)'):
      head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn)

  def test_loss_fn_arg_features_ok(self):
    def _loss_fn(labels, logits, features):
      del labels, logits, features  # Unused
    head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn)

  def test_loss_fn_arg_invalid(self):
    def _loss_fn(labels, logits, name=None):
      del labels, logits, name  # Unused
    with self.assertRaisesRegexp(
        ValueError,
        r'loss_fn has unexpected args: \[\'name\'\]'):
      head_lib.multi_label_head(n_classes=3, loss_fn=_loss_fn)

  def test_name(self):
    head = head_lib.multi_label_head(n_classes=4, name='foo')
    self.assertEqual('foo', head.name)

  def test_predict(self):
    n_classes = 4
    head = head_lib.multi_label_head(n_classes)
    self.assertEqual(n_classes, head.logits_dimension)

    logits = np.array(
        [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
    expected_probabilities = _sigmoid(logits)
    expected_export_classes = [[b'0', b'1', b'2', b'3']] * 2

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)

    self.assertItemsEqual(
        (_DEFAULT_SERVING_KEY, 'predict', 'classification'),
        spec.export_outputs.keys())

    # Assert predictions and export_outputs.
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(logits,
                          predictions[prediction_keys.PredictionKeys.LOGITS])
      self.assertAllClose(
          expected_probabilities,
          predictions[prediction_keys.PredictionKeys.PROBABILITIES])

      self.assertAllClose(
          expected_probabilities,
          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].scores))
      self.assertAllEqual(
          expected_export_classes,
          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))

  def test_predict_with_label_vocabulary(self):
    n_classes = 4
    head = head_lib.multi_label_head(
        n_classes, label_vocabulary=['foo', 'bar', 'foobar', 'barfoo'])

    logits = np.array(
        [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
    expected_export_classes = [[b'foo', b'bar', b'foobar', b'barfoo']] * 2

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)

    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertAllEqual(
          expected_export_classes,
          sess.run(spec.export_outputs[_DEFAULT_SERVING_KEY].classes))

  def test_weight_should_not_impact_prediction(self):
    n_classes = 4
    head = head_lib.multi_label_head(n_classes, weight_column='example_weights')
    self.assertEqual(n_classes, head.logits_dimension)

    logits = np.array(
        [[0., 1., 2., -1.], [-1., -2., -3., 1.]], dtype=np.float32)
    expected_probabilities = _sigmoid(logits)

    weights_2x1 = [[1.], [2.]]
    spec = head.create_estimator_spec(
        features={
            'x': np.array(((42,),), dtype=np.int32),
            'example_weights': weights_2x1,
        },
        mode=model_fn.ModeKeys.PREDICT,
        logits=logits)

    # Assert predictions and export_outputs.
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      predictions = sess.run(spec.predictions)
      self.assertAllClose(logits,
                          predictions[prediction_keys.PredictionKeys.LOGITS])
      self.assertAllClose(
          expected_probabilities,
          predictions[prediction_keys.PredictionKeys.PROBABILITIES])

  def test_eval_create_loss(self):
    """Tests head.create_loss for eval mode."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    expected_training_loss = np.sum(
        _sigmoid_cross_entropy(labels=labels, logits=logits))
    actual_training_loss = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)[0]
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(expected_training_loss,
                          actual_training_loss.eval())

  def test_eval_create_loss_large_logits(self):
    """Tests head.create_loss for eval mode and large logits."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_training_loss = np.sum(
        np.array([[(10. + 10.) / 2.], [(15. + 0.) / 2.]], dtype=np.float32))
    actual_training_loss = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)[0]
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_training_loss, actual_training_loss.eval(), atol=1e-4)

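  # A numpy-only sanity check (illustrative sketch) of the large-logits
  # approximation quoted in the comments above and reused in later tests:
  # for logits of large magnitude, -log(sigmoid(x)) ~= max(-x, 0) and
  # -log(1 - sigmoid(x)) ~= max(x, 0).
  def test_sigmoid_cross_entropy_large_logits_approximation(self):
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float64)
    labels = np.array([[1, 0], [1, 1]], dtype=np.float64)
    # Piecewise-linear approximation, averaged over classes as in
    # _sigmoid_cross_entropy.
    approx = np.mean(
        labels * np.maximum(-logits, 0.) +
        (1 - labels) * np.maximum(logits, 0.),
        axis=-1, keepdims=True)
    self.assertAllClose(
        _sigmoid_cross_entropy(labels=labels, logits=logits), approx,
        atol=1e-4)
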
  def test_eval_create_loss_labels_wrong_shape(self):
    """Tests head.create_loss for eval mode when labels has the wrong shape."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)

    logits = np.array([[-1., 1.], [-1.5, 1.]], dtype=np.float32)
    labels_placeholder = array_ops.placeholder(dtype=dtypes.int64)
    actual_training_loss = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels_placeholder)[0]
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[expected_labels_shape: \] \[2 2\] \[labels_shape: \] \[2 1\]'):
        actual_training_loss.eval({
            labels_placeholder: np.array([[1], [1]], dtype=np.int64)
        })
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels shape must be \[D0, D1, ... DN, 2\]\..*'
          r'\[Received shape: \] \[2\]'):
        actual_training_loss.eval({
            labels_placeholder: np.array([1, 1], dtype=np.int64)
        })

  def test_eval_create_loss_loss_fn(self):
    """Tests head.create_loss for eval mode and custom loss_fn."""
    loss = np.array([[1.], [2.]], dtype=np.float32)
    logits_input = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels_input = np.array([[1, 0], [1, 1]], dtype=np.int64)
    def _loss_fn(labels, logits):
      check_labels = control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(labels, labels_input)),
          data=[labels])
      check_logits = control_flow_ops.Assert(
          math_ops.reduce_all(math_ops.equal(logits, logits_input)),
          data=[logits])
      with ops.control_dependencies([check_labels, check_logits]):
        return constant_op.constant(loss)
    head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn)

    actual_training_loss = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits_input,
        labels=labels_input)[0]
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(np.sum(loss), actual_training_loss.eval())

  def test_eval_create_loss_loss_fn_wrong_shape(self):
    """Tests custom loss_fn that returns Tensor of unexpected shape."""
    loss = np.array([1., 2.], dtype=np.float32)
    def _loss_fn(labels, logits):
      del labels, logits  # Unused
      return constant_op.constant(loss)
    head = head_lib.multi_label_head(n_classes=2, loss_fn=_loss_fn)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    actual_training_loss = head.create_loss(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)[0]
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[loss_fn must return Tensor of shape \[D0, D1, ... DN, 1\]\. \] '
          r'\[logits_shape: \] \[2 2\] \[loss_shape: \] \[2\]'):
        actual_training_loss.eval()

  def test_eval_labels_none(self):
    """Tests that error is raised when labels is None."""
    head = head_lib.multi_label_head(n_classes=2)

    with self.assertRaisesRegexp(
        ValueError, r'You must provide a labels Tensor\. Given: None\.'):
      head.create_estimator_spec(
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=model_fn.ModeKeys.EVAL,
          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
          labels=None)

  def _test_eval(
      self, head, logits, labels, expected_loss, expected_metrics,
      features=None, regularization_losses=None):
    spec = head.create_estimator_spec(
        features=features or {},
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels,
        regularization_losses=regularization_losses)

    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, metrics = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of both update (in `metrics`) and value ops.
      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_eval(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)
    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Sum over examples.
    expected_loss = np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.5972,
    }
    self._test_eval(
        head=head,
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)

  def test_eval_sparse_labels(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes)
    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=[0, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Sum over examples.
    expected_loss = (
        np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
    )
    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.5972,
    }
    self._test_eval(
        head=head,
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)

  def test_eval_with_regularization_losses(self):
    n_classes = 2
    head = head_lib.multi_label_head(
        n_classes, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    regularization_losses = [1.5, 0.5]
    expected_regularization_loss = 2.
    # unregularized_loss = sum(
    #     labels * -log(sigmoid(logits)) +
    #     (1 - labels) * -log(1 - sigmoid(logits))) / batch_size
    expected_unregularized_loss = np.sum(
        _sigmoid_cross_entropy(labels=labels, logits=logits)) / 2.
    expected_regularized_loss = (
        expected_unregularized_loss + expected_regularization_loss)
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: expected_unregularized_loss,
        keys.LOSS_REGULARIZATION: expected_regularization_loss,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.5972,
    }
    self._test_eval(
        head=head,
        logits=logits,
        labels=labels,
        expected_loss=expected_regularized_loss,
        expected_metrics=expected_metrics,
        regularization_losses=regularization_losses)

  def test_eval_with_label_vocabulary(self):
    n_classes = 2
    head = head_lib.multi_label_head(
        n_classes, label_vocabulary=['class0', 'class1'])
    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=['class0', 'class0', 'class1'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    labels_multi_hot = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Sum over examples.
    expected_loss = (
        np.sum(_sigmoid_cross_entropy(labels=labels_multi_hot, logits=logits))
    )
    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.5972,
    }
    self._test_eval(
        head=head,
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)

  def test_eval_with_thresholds(self):
    n_classes = 2
    thresholds = [0.25, 0.5, 0.75]
    head = head_lib.multi_label_head(n_classes, thresholds=thresholds)

    logits = np.array([[-1., 1.], [-1.5, 1.5]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # Sum over examples.
    expected_loss = (
        np.sum(_sigmoid_cross_entropy(labels=labels, logits=logits))
    )

    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over examples.
        keys.LOSS_MEAN: expected_loss / 2,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.3333,
        keys.AUC_PR: 0.5972,
        keys.ACCURACY_AT_THRESHOLD % thresholds[0]: 2. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[0]: 2. / 3.,
        keys.RECALL_AT_THRESHOLD % thresholds[0]: 2. / 3.,
        keys.ACCURACY_AT_THRESHOLD % thresholds[1]: 1. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[1]: 1. / 2.,
        keys.RECALL_AT_THRESHOLD % thresholds[1]: 1. / 3.,
        keys.ACCURACY_AT_THRESHOLD % thresholds[2]: 2. / 4.,
        keys.PRECISION_AT_THRESHOLD % thresholds[2]: 1. / 1.,
        keys.RECALL_AT_THRESHOLD % thresholds[2]: 1. / 3.,
    }

    self._test_eval(
        head=head,
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)

  def test_eval_with_weights(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='example_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, weighted sum over examples.
    expected_loss = 25.

    spec = head.create_estimator_spec(
        features={
            'x': np.array([[41], [42]], dtype=np.int32),
            'example_weights': np.array([[1.], [2.]], dtype=np.float32),
        },
        mode=model_fn.ModeKeys.EVAL,
        logits=logits,
        labels=labels)

    keys = metric_keys.MetricKeys
    expected_metrics = {
        # Average loss over weighted examples.
        keys.LOSS_MEAN: expected_loss / 3,
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.2000,
        keys.AUC_PR: 0.5833,
    }

    # Assert spec contains expected tensors.
    self.assertIsNotNone(spec.loss)
    self.assertItemsEqual(expected_metrics.keys(), spec.eval_metric_ops.keys())
    self.assertIsNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, and metrics.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNone(spec.scaffold.summary_op)
      value_ops = {k: spec.eval_metric_ops[k][0] for k in spec.eval_metric_ops}
      update_ops = {k: spec.eval_metric_ops[k][1] for k in spec.eval_metric_ops}
      loss, metrics = sess.run((spec.loss, update_ops))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      # Check results of both update (in `metrics`) and value ops.
      self.assertAllClose(expected_metrics, metrics, rtol=tol, atol=tol)
      self.assertAllClose(
          expected_metrics, {k: value_ops[k].eval() for k in value_ops},
          rtol=tol,
          atol=tol)

  def test_train_create_loss_large_logits(self):
    """Tests head.create_loss for train mode and large logits."""
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='example_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    weights = np.array([[1.], [2.]], dtype=np.float32)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]
    expected_weights = [[1.], [2.]]
    expected_training_loss = 1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.
    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
        features={
            'x': np.array(((42,),), dtype=np.int32),
            'example_weights': weights
        },
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), atol=1e-4)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4)
      self.assertAllClose(expected_weights, actual_weights.eval())

  def test_train_create_loss_loss_reduction(self):
    """Tests head.create_loss with loss_reduction."""
    n_classes = 2
    head = head_lib.multi_label_head(
        n_classes, weight_column='example_weights',
        loss_reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS)

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    weights = np.array([[1.], [2.]], dtype=np.float32)
    # loss = labels * -log(sigmoid(logits)) +
    #        (1 - labels) * -log(1 - sigmoid(logits))
    # For large logits, this is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits
    expected_unreduced_loss = [[(10. + 10.) / 2.], [(15. + 0.) / 2.]]
    expected_weights = [[1.], [2.]]
    expected_training_loss = (1. * (10. + 10.) / 2. + 2. * (15. + 0.) / 2.) / 2.
    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
        features={
            'x': np.array(((42,),), dtype=np.int32),
            'example_weights': weights
        },
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), atol=1e-4)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), atol=1e-4)
      self.assertAllClose(expected_weights, actual_weights.eval())

  def test_train_labels_none(self):
    """Tests that error is raised when labels is None."""
    head = head_lib.multi_label_head(n_classes=2)
    def _no_op_train_fn(loss):
      del loss
      return control_flow_ops.no_op()

    with self.assertRaisesRegexp(
        ValueError, r'You must provide a labels Tensor\. Given: None\.'):
      head.create_estimator_spec(
          features={'x': np.array(((42,),), dtype=np.int32)},
          mode=model_fn.ModeKeys.TRAIN,
          logits=np.array([[-10., 10.], [-15., 10.]], dtype=np.float32),
          labels=None,
          train_op_fn=_no_op_train_fn)

  def test_train_invalid_indicator_labels(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # The value 2 is outside the allowed range.
    labels = np.array([[2, 0], [1, 1]], dtype=np.int64)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels must be an integer indicator Tensor with values in '
          r'\[0, 1\]'):
        sess.run(spec.loss)

  def test_train_invalid_sparse_labels(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # The value 2 is outside the allowed range.
    labels = sparse_tensor.SparseTensor(
        values=[2, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'labels must be an integer SparseTensor with values in \[0, 2\)'):
        sess.run(spec.loss)

  def _test_train(self, head, logits, labels, expected_loss):
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over examples.
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 2,
      }, summary_str, tol)

  def test_train(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over examples.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_sparse_labels(self):
    head = head_lib.multi_label_head(n_classes=2)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=[0, 0, 1],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over examples.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_with_label_vocabulary(self):
    head = head_lib.multi_label_head(
        n_classes=2, label_vocabulary=['class0', 'class1'])
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    # Equivalent to multi_hot = [[1, 0], [1, 1]]
    labels = sparse_tensor.SparseTensor(
        values=['class0', 'class0', 'class1'],
        indices=[[0, 0], [1, 0], [1, 1]],
        dense_shape=[2, 2])
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, sum over examples.
    expected_loss = 17.5
    self._test_train(
        head=head, logits=logits, labels=labels, expected_loss=expected_loss)

  def test_train_with_regularization_losses(self):
    head = head_lib.multi_label_head(
        n_classes=2, loss_reduction=losses.Reduction.SUM_OVER_BATCH_SIZE)
    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    regularization_losses = [1.5, 0.5]
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes and over batch and add regularization loss.
    expected_loss = 35. / 4. + 2.
    expected_summaries = {
        metric_keys.MetricKeys.LOSS: expected_loss,
        metric_keys.MetricKeys.LOSS_REGULARIZATION: 2.,
    }
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'x': np.array(((42,),), dtype=np.int32)},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn,
        regularization_losses=regularization_losses)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, expected_summaries, summary_str, tol)

  def test_train_with_weights(self):
    n_classes = 2
    head = head_lib.multi_label_head(n_classes, weight_column='example_weights')

    logits = np.array([[-10., 10.], [-15., 10.]], dtype=np.float32)
    labels = np.array([[1, 0], [1, 1]], dtype=np.int64)
    # For large logits, sigmoid cross entropy loss is approximated as:
    # loss = labels * (logits < 0) * (-logits) +
    #        (1 - labels) * (logits > 0) * logits =>
    # expected_unweighted_loss = [[10., 10.], [15., 0.]]
    # Average over classes, weighted sum over examples.
    expected_loss = 25.
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={
            'x': np.array([[41], [42]], dtype=np.int32),
            'example_weights': np.array([[1.], [2.]], dtype=np.float32),
        },
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    self.assertIsNotNone(spec.loss)
    self.assertEqual({}, spec.eval_metric_ops)
    self.assertIsNotNone(spec.train_op)
    self.assertIsNone(spec.export_outputs)
    _assert_no_hooks(self, spec)

    # Assert predictions, loss, train_op, and summaries.
    tol = 1e-3
    with self.test_session() as sess:
      _initialize_variables(self, spec.scaffold)
      self.assertIsNotNone(spec.scaffold.summary_op)
      loss, train_result, summary_str = sess.run((spec.loss, spec.train_op,
                                                  spec.scaffold.summary_op))
      self.assertAllClose(expected_loss, loss, rtol=tol, atol=tol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)
      _assert_simple_summaries(self, {
          metric_keys.MetricKeys.LOSS: expected_loss,
          # Average loss over weighted examples.
          metric_keys.MetricKeys.LOSS_MEAN: expected_loss / 3,
      }, summary_str, tol)

  def test_multi_dim_weighted_train_create_loss(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # unreduced_loss =
    #     [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #   = [[20/3, 10/3], [4, 8]]
    expected_unreduced_loss = [[[20./3.], [10./3.]], [[4.], [8.]]]
    # weights are reshaped to [2, 2, 1] to match logits.
    expected_weights = [[[1.], [1.5]], [[2.], [2.5]]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_training_loss = 39.6667
    training_loss, unreduced_loss, actual_weights, _ = head.create_loss(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels)
    atol = 1.e-3
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      self.assertAllClose(
          expected_training_loss, training_loss.eval(), atol=atol)
      self.assertAllClose(
          expected_unreduced_loss, unreduced_loss.eval(), atol=atol)
      self.assertAllClose(expected_weights, actual_weights.eval())

  def test_multi_dim_weighted_train(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #      = [[20/3, 10/3], [4, 8]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_loss = 39.6667
    expected_train_result = 'my_train_op'
    def _train_op_fn(loss):
      return string_ops.string_join(
          [constant_op.constant(expected_train_result),
           string_ops.as_string(loss, precision=3)])

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)

    atol = 1.e-3
    with self.test_session() as sess:
      _initialize_variables(self, monitored_session.Scaffold())
      loss, train_result = sess.run((spec.loss, spec.train_op))
      self.assertAllClose(expected_loss, loss, atol=atol)
      self.assertEqual(
          six.b('{0:s}{1:.3f}'.format(expected_train_result, expected_loss)),
          train_result)

  def test_multi_dim_weights_wrong_inner_dim(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 1]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1.], [2.]], dtype=np.float32)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 1\]'):
        spec.loss.eval()

  def test_multi_dim_weights_wrong_outer_dim(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2, 3]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[[1., 1., 1.], [1.5, 1.5, 1.5]],
                        [[2., 2., 2.], [2.5, 2.5, 2.5]]], dtype=np.float32)
    weights_placeholder = array_ops.placeholder(dtype=dtypes.float32)
    def _train_op_fn(loss):
      del loss
      return control_flow_ops.no_op()

    spec = head.create_estimator_spec(
        features={'weights': weights_placeholder},
        mode=model_fn.ModeKeys.TRAIN,
        logits=logits,
        labels=labels,
        train_op_fn=_train_op_fn)
    with self.test_session():
      _initialize_variables(self, monitored_session.Scaffold())
      with self.assertRaisesRegexp(
          errors.InvalidArgumentError,
          r'\[logits_shape: \] \[2 2 3\] \[weights_shape: \] \[2 2 3\]'):
        spec.loss.eval({weights_placeholder: weights})

  def test_multi_dim_weighted_eval(self):
    """Logits and labels of shape [2, 2, 3], weights [2, 2]."""
    head = head_lib.multi_label_head(n_classes=3, weight_column='weights')

    logits = np.array([[[-10., 10., -10.], [10., -10., 10.]],
                       [[-12., 12., -12.], [12., -12., 12.]]], dtype=np.float32)
    labels = np.array([[[1, 0, 0], [1, 0, 0]],
                       [[0, 1, 1], [0, 1, 1]]], dtype=np.int64)
    weights = np.array([[1., 1.5], [2., 2.5]], dtype=np.float32)
    # loss = [[10 + 10 + 0, 0 + 0 + 10], [0 + 0 + 12, 12 + 12 + 0]] / 3
    #      = [[20/3, 10/3], [4, 8]]
    # weighted_sum_loss = 1*20/3 + 1.5*10/3 + 2*4 + 2.5*8 = 39.6667
    expected_loss = 39.6667
    keys = metric_keys.MetricKeys
    expected_metrics = {
        keys.LOSS_MEAN: expected_loss / np.sum(weights),
        # auc and auc_pr cannot be reliably calculated for only 4 samples, but
        # this assert tests that the algorithm remains consistent.
        keys.AUC: 0.4977,
        keys.AUC_PR: 0.4037,
    }
    self._test_eval(
        head=head,
        features={'weights': weights},
        logits=logits,
        labels=labels,
        expected_loss=expected_loss,
        expected_metrics=expected_metrics)


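# A minimal usage sketch of the API exercised above: wiring multi_label_head
# into a custom model_fn. The feature name 'x', the single-variable logit
# computation, and the GradientDescentOptimizer are illustrative assumptions,
# not part of the head API.
def _example_model_fn(features, labels, mode):
  # Local imports keep the sketch self-contained.
  from tensorflow.python.ops import variable_scope
  from tensorflow.python.training import gradient_descent
  head = head_lib.multi_label_head(n_classes=2)
  # Assumed toy logit computation: one trainable weight applied to the numeric
  # feature 'x', yielding [batch_size, n_classes] logits.
  x = math_ops.to_float(features['x'])
  logits_weight = variable_scope.get_variable('logits_weight', shape=[1, 2])
  logits = math_ops.matmul(x, logits_weight)
  optimizer = gradient_descent.GradientDescentOptimizer(learning_rate=0.1)
  return head.create_estimator_spec(
      features=features,
      mode=mode,
      logits=logits,
      labels=labels,
      train_op_fn=optimizer.minimize)

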
if __name__ == '__main__':
  test.main()