# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstractions for the head(s) of a model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary

_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def multi_class_head(n_classes,
                     weight_column=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Creates a `_Head` for multi-class classification.

  Uses `sparse_softmax_cross_entropy` loss.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
  In many applications, the shape is `[batch_size, n_classes]`.

  `labels` must be a dense `Tensor` with shape matching `logits`, namely
  `[D0, D1, ... DN, 1]`. If `label_vocabulary` is given, `labels` must be a
  string `Tensor` with values from the vocabulary. If `label_vocabulary` is
  not given, `labels` must be an integer `Tensor` with values specifying the
  class index.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  The loss is the weighted sum over the input dimensions. Namely, if the input
  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
  `batch_size`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
  the input labels before passing them to `loss_fn`.
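
  Usage sketch (hedged: the feature name `'x'`, the dense layer, and the
  optimizer below are illustrative placeholders, not part of this API):

  ```python
  def _my_model_fn(features, labels, mode):
    # Any logits of shape [batch_size, n_classes] works here; a single dense
    # layer is used purely for illustration.
    logits = tf.layers.dense(features['x'], units=3)
    head = tf.contrib.estimator.multi_class_head(n_classes=3)
    optimizer = tf.train.AdagradOptimizer(learning_rate=0.1)
    return head.create_estimator_spec(
        features=features, mode=mode, logits=logits, labels=labels,
        train_op_fn=lambda loss: optimizer.minimize(
            loss, global_step=tf.train.get_global_step()))

  estimator = tf.estimator.Estimator(model_fn=_my_model_fn)
  ```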

  Args:
    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_vocabulary: A list or tuple of strings representing possible label
      values. If it is not given, labels are assumed to be already encoded as
      integers within [0, n_classes). If given, labels must be of string type
      and take values in `label_vocabulary`. Note that errors will be raised
      if `label_vocabulary` is not provided but labels are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over the batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-class classification.

  Raises:
    ValueError: if `n_classes`, `label_vocabulary` or `loss_reduction` is
      invalid.
  """
  return head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint:disable=protected-access
      n_classes=n_classes,
      weight_column=weight_column,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def binary_classification_head(
    weight_column=None,
    thresholds=None,
    label_vocabulary=None,
    loss_reduction=losses.Reduction.SUM,
    loss_fn=None,
    name=None):
  """Creates a `_Head` for single-label binary classification.

  This head uses `sigmoid_cross_entropy_with_logits` loss.

  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
  In many applications, the shape is `[batch_size, 1]`.

  `labels` must be a dense `Tensor` with shape matching `logits`, namely
  `[D0, D1, ... DN, 1]`. If `label_vocabulary` is given, `labels` must be a
  string `Tensor` with values from the vocabulary. If `label_vocabulary` is
  not given, `labels` must be a float `Tensor` with values in the interval
  `[0, 1]`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  The loss is the weighted sum over the input dimensions. Namely, if the input
  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
  `batch_size`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
  the input labels before passing them to `loss_fn`.
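
  Usage sketch inside a custom `model_fn` (hedged: `features`, `mode`,
  `logits`, `labels` and `optimizer` are assumed to be provided by the
  surrounding model code):

  ```python
  # `logits` must have shape [batch_size, 1] for this head.
  head = tf.contrib.estimator.binary_classification_head(
      thresholds=[0.25, 0.5, 0.75])
  spec = head.create_estimator_spec(
      features=features, mode=mode, logits=logits, labels=labels,
      train_op_fn=lambda loss: optimizer.minimize(
          loss, global_step=tf.train.get_global_step()))
  ```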

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    thresholds: Iterable of floats in the range `(0, 1)`. For binary
      classification metrics such as precision and recall, an eval metric is
      generated for each threshold value. The threshold is applied to the
      logistic values to determine the binary classification (i.e., above the
      threshold is `true`, below is `false`).
    label_vocabulary: A list or tuple of strings representing possible label
      values. If it is not given, labels must be float with values within
      [0, 1]. If given, labels must be of string type and take values in
      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
      is not provided but labels are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over the batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for binary classification.

  Raises:
    ValueError: If `thresholds` contains a value outside of `(0, 1)`.
    ValueError: If `loss_reduction` is invalid.
  """
  return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint:disable=protected-access
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def regression_head(weight_column=None,
                    label_dimension=1,
                    loss_reduction=losses.Reduction.SUM,
                    loss_fn=None,
                    name=None):
  """Creates a `_Head` for regression using the `mean_squared_error` loss.

  The loss is the weighted sum over all input dimensions. Namely, if the input
  labels have shape `[batch_size, label_dimension]`, the loss is the weighted
  sum over both `batch_size` and `label_dimension`.

  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
  In many applications, the shape is `[batch_size, label_dimension]`.

  The `labels` shape must match `logits`, namely
  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
  `[D0, D1, ... DN]` is also supported.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
  `[D0, D1, ... DN, label_dimension]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, label_dimension]`.
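
  Usage sketch with a canned estimator (hedged: the feature column and hidden
  units are illustrative):

  ```python
  # Two real-valued targets per example, so logits and labels both have
  # shape [batch_size, 2].
  head = tf.contrib.estimator.regression_head(label_dimension=2)
  estimator = tf.contrib.estimator.DNNEstimator(
      head=head,
      hidden_units=[10],
      feature_columns=[tf.feature_column.numeric_column('x')])
  ```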

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over the batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for linear regression.

  Raises:
    ValueError: If `label_dimension` or `loss_reduction` is invalid.
  """
  return head_lib._regression_head_with_mean_squared_error_loss(  # pylint:disable=protected-access
      weight_column=weight_column,
      label_dimension=label_dimension,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss, averaging over classes and taking a
  weighted sum over the batch. Namely, if the input logits have shape
  `[batch_size, n_classes]`, the loss is the average over `n_classes` and the
  weighted sum over `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:
  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.
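
  A sketch of the accepted label encodings (the values are illustrative):

  ```python
  head = tf.contrib.estimator.multi_label_head(n_classes=3)
  # Multi-hot encoding: example 0 has classes {0, 2}; example 1 has class 1.
  labels = tf.constant([[1, 0, 1], [0, 1, 0]], dtype=tf.int64)
  # The equivalent SparseTensor of class indices.
  labels = tf.SparseTensor(
      indices=[[0, 0], [0, 1], [1, 0]],
      values=[0, 2, 1],
      dense_shape=[2, 2])
  ```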

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.  Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list of strings representing possible label values. If
      it is not given, labels must already be encoded as integers within
      [0, n_classes) or as a multi-hot `Tensor`. If given, labels must be a
      string `SparseTensor` with values in `label_vocabulary`. Note that
      errors will be raised if `label_vocabulary` is not provided but labels
      are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how
      to reduce training loss over the batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `loss_reduction` or `loss_fn` is
      invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-label classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
  return _MultiLabelHead(
      n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
      label_vocabulary=label_vocabulary, loss_reduction=loss_reduction,
      loss_fn=loss_fn, name=name)


class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
  """`_Head` for multi-label classification."""

  def __init__(self,
               n_classes,
               weight_column=None,
               thresholds=None,
               label_vocabulary=None,
               loss_reduction=losses.Reduction.SUM,
               loss_fn=None,
               name=None):
    self._n_classes = n_classes
    self._weight_column = weight_column
    self._thresholds = thresholds
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def logits_dimension(self):
    return self._n_classes

  def _process_labels(self, labels):
    if labels is None:
      raise ValueError(
          'You must provide a labels Tensor. Given: None. '
          'Suggested troubleshooting steps: Check that your data contain '
          'your label feature. Check that your input_fn properly parses and '
          'returns labels.')
    if isinstance(labels, sparse_tensor.SparseTensor):
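      # The SparseTensor holds class indices (or vocabulary strings) per
      # example and is converted below to a dense multi-hot indicator. For
      # example, with n_classes=3 a row with values [0, 2] becomes [1, 0, 1].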
      if labels.dtype == dtypes.string:
        label_ids_values = lookup_ops.index_table_from_tensor(
            vocabulary_list=tuple(self._label_vocabulary),
            name='class_id_lookup').lookup(labels.values)
        label_ids = sparse_tensor.SparseTensor(
            indices=labels.indices,
            values=label_ids_values,
            dense_shape=labels.dense_shape)
        return math_ops.to_int64(
            sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
      else:
        err_msg = (
            r'labels must be an integer SparseTensor with values in '
            r'[0, {})'.format(self._n_classes))
        assert_int = check_ops.assert_integer(
            labels.values, message=err_msg)
        assert_less = check_ops.assert_less(
            labels.values,
            ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
            message=err_msg)
        assert_greater = check_ops.assert_non_negative(
            labels.values, message=err_msg)
        with ops.control_dependencies(
            [assert_int, assert_less, assert_greater]):
          return math_ops.to_int64(
              sparse_ops.sparse_to_indicator(labels, self._n_classes))
    err_msg = (
        r'labels must be an integer indicator Tensor with values in [0, 1]')
    return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access

  def create_loss(self, features, mode, logits, labels):
    """See `Head`."""
    del mode  # Unused for this head.
    logits = ops.convert_to_tensor(logits)
    processed_labels = self._process_labels(labels)
    processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
        labels=processed_labels, logits=logits,
        expected_labels_dimension=self.logits_dimension)
    if self._loss_fn:
      unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
          loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
          features=features, expected_loss_dim=1)
    else:
      unweighted_loss = losses.sigmoid_cross_entropy(
          multi_class_labels=processed_labels, logits=logits,
          reduction=losses.Reduction.NONE)
      # Averages loss over classes.
      unweighted_loss = math_ops.reduce_mean(
          unweighted_loss, axis=-1, keep_dims=True)
    weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
        features=features, weight_column=self._weight_column, logits=logits)
    training_loss = losses.compute_weighted_loss(
        unweighted_loss, weights=weights, reduction=self._loss_reduction)
    return head_lib.LossSpec(
        training_loss=training_loss,
        unreduced_loss=unweighted_loss,
        weights=weights,
        processed_labels=processed_labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None,
      regularization_losses=None):
    """Returns an `EstimatorSpec`.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`.
        For many applications, the shape is `[batch_size, n_classes]`.
      labels: Labels with shape matching `logits`. Can be a multi-hot `Tensor`
        with shape `[D0, D1, ... DN, n_classes]` or a `SparseTensor` with
        `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is a required argument
        when `mode` equals `TRAIN` or `EVAL`.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns
        `train_op`. Required in TRAIN mode.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses. These losses are
        usually expressed as a batch average, so for best results users need to
        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
        `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
        avoid scaling errors.

    Returns:
      `EstimatorSpec`.

    Raises:
      ValueError: If `train_op_fn` is `None` in TRAIN mode.
    """
    with ops.name_scope(self._name, 'head'):
      logits = head_lib._check_logits_final_dim(logits, self.logits_dimension)  # pylint:disable=protected-access

      # Predict.
      pred_keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        probabilities = math_ops.sigmoid(logits, name=pred_keys.PROBABILITIES)
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.PROBABILITIES: probabilities,
        }
      if mode == model_fn.ModeKeys.PREDICT:
        classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
            scores=probabilities, n_classes=self._n_classes,
            label_vocabulary=self._label_vocabulary)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classifier_output,
                head_lib._CLASSIFY_SERVING_KEY: classifier_output,  # pylint:disable=protected-access
                head_lib._PREDICT_SERVING_KEY: (  # pylint:disable=protected-access
                    export_output.PredictOutput(predictions))
            })

      (training_loss, unreduced_loss, weights,
       processed_labels) = self.create_loss(
           features=features, mode=mode, logits=logits, labels=labels)
      if regularization_losses:
        regularization_loss = math_ops.add_n(regularization_losses)
        regularized_training_loss = math_ops.add_n(
            [training_loss, regularization_loss])
      else:
        regularization_loss = None
        regularized_training_loss = training_loss

      # Eval.
      if mode == model_fn.ModeKeys.EVAL:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=regularized_training_loss,
            eval_metric_ops=self._eval_metric_ops(
                labels=processed_labels,
                probabilities=probabilities,
                weights=weights,
                unreduced_loss=unreduced_loss,
                regularization_loss=regularization_loss))

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn cannot be None.')
      # Only summarize mean_loss for SUM reduction to preserve backwards
      # compatibility. Otherwise skip it to avoid unnecessary computation.
      if self._loss_reduction == losses.Reduction.SUM:
        example_weight_sum = math_ops.reduce_sum(
            weights * array_ops.ones_like(unreduced_loss))
        mean_loss = training_loss / example_weight_sum
      else:
        mean_loss = None
    with ops.name_scope(''):
      keys = metric_keys.MetricKeys
      summary.scalar(
          head_lib._summary_key(self._name, keys.LOSS),  # pylint:disable=protected-access
          regularized_training_loss)
      if mean_loss is not None:
        summary.scalar(
            head_lib._summary_key(self._name, keys.LOSS_MEAN),  # pylint:disable=protected-access
            mean_loss)
      if regularization_loss is not None:
        summary.scalar(
            head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION),  # pylint:disable=protected-access
            regularization_loss)
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op_fn(regularized_training_loss))

  def _eval_metric_ops(
      self, labels, probabilities, weights, unreduced_loss,
      regularization_loss):
    """Returns a dict of metrics for eval_metric_ops."""
    with ops.name_scope(
        None, 'metrics',
        [labels, probabilities, weights, unreduced_loss, regularization_loss]):
      keys = metric_keys.MetricKeys
      metric_ops = {
          # Estimator already adds a metric for loss.
          head_lib._summary_key(self._name, keys.LOSS_MEAN):  # pylint:disable=protected-access
              metrics_lib.mean(
                  values=unreduced_loss,
                  weights=weights,
                  name=keys.LOSS_MEAN),
          head_lib._summary_key(self._name, keys.AUC):  # pylint:disable=protected-access
              metrics_lib.auc(labels=labels, predictions=probabilities,
                              weights=weights, name=keys.AUC),
          head_lib._summary_key(self._name, keys.AUC_PR):  # pylint:disable=protected-access
              metrics_lib.auc(labels=labels, predictions=probabilities,
                              weights=weights, curve='PR',
                              name=keys.AUC_PR),
      }
      if regularization_loss is not None:
        loss_regularization_key = head_lib._summary_key(  # pylint:disable=protected-access
            self._name, keys.LOSS_REGULARIZATION)
        metric_ops[loss_regularization_key] = (
            metrics_lib.mean(
                values=regularization_loss,
                name=keys.LOSS_REGULARIZATION))
      for threshold in self._thresholds:
        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, accuracy_key)] = (  # pylint:disable=protected-access
            head_lib._accuracy_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=accuracy_key))
        # Precision for positive examples.
        precision_key = keys.PRECISION_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, precision_key)] = (  # pylint:disable=protected-access
            head_lib._precision_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=precision_key))
        # Recall for positive examples.
        recall_key = keys.RECALL_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, recall_key)] = (  # pylint:disable=protected-access
            head_lib._recall_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=recall_key))
    return metric_ops