# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Abstractions for the head(s) of a model."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.estimator import model_fn
from tensorflow.python.estimator.canned import head as head_lib
from tensorflow.python.estimator.canned import metric_keys
from tensorflow.python.estimator.canned import prediction_keys
from tensorflow.python.estimator.export import export_output
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import sparse_tensor
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import lookup_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import metrics as metrics_lib
from tensorflow.python.ops import sparse_ops
from tensorflow.python.ops.losses import losses
from tensorflow.python.saved_model import signature_constants
from tensorflow.python.summary import summary

_DEFAULT_SERVING_KEY = signature_constants.DEFAULT_SERVING_SIGNATURE_DEF_KEY


def multi_class_head(n_classes,
                     weight_column=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Creates a `_Head` for multi class classification.

  Uses `sparse_softmax_cross_entropy` loss.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`.
  In many applications, the shape is `[batch_size, n_classes]`.

  `labels` must be a dense `Tensor` whose leading dimensions match those of
  `logits`, namely `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels`
  must be a string `Tensor` with values from the vocabulary. If
  `label_vocabulary` is not given, `labels` must be an integer `Tensor` with
  values specifying the class index.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  The loss is the weighted sum over the input dimensions. Namely, if the input
  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
  `batch_size`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support integer `labels` with
  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
  the input labels before passing them to `loss_fn`.

  Args:
    n_classes: Number of classes, must be greater than 2 (for 2 classes, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_vocabulary: A list or tuple of strings representing possible label
      values. If it is not given, that means labels are already encoded as an
      integer within [0, n_classes). If given, labels must be of string type and
      have any value in `label_vocabulary`. Note that errors will be raised if
      `label_vocabulary` is not provided but labels are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi class classification.

  Raises:
    ValueError: if `n_classes`, `label_vocabulary` or `loss_reduction` is
      invalid.
  """
  return head_lib._multi_class_head_with_softmax_cross_entropy_loss(  # pylint:disable=protected-access
      n_classes=n_classes,
      weight_column=weight_column,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def binary_classification_head(
    weight_column=None,
    thresholds=None,
    label_vocabulary=None,
    loss_reduction=losses.Reduction.SUM,
    loss_fn=None,
    name=None):
  """Creates a `_Head` for single label binary classification.

  This head uses `sigmoid_cross_entropy_with_logits` loss.

  The head expects `logits` with shape `[D0, D1, ... DN, 1]`.
  In many applications, the shape is `[batch_size, 1]`.

  `labels` must be a dense `Tensor` with shape matching `logits`, namely
  `[D0, D1, ... DN, 1]`. If `label_vocabulary` given, `labels` must be a string
  `Tensor` with values from the vocabulary. If `label_vocabulary` is not given,
  `labels` must be float `Tensor` with values in the interval `[0, 1]`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  The loss is the weighted sum over the input dimensions. Namely, if the input
  labels have shape `[batch_size, 1]`, the loss is the weighted sum over
  `batch_size`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support float `labels` with
  shape `[D0, D1, ... DN, 1]`. Namely, the head applies `label_vocabulary` to
  the input labels before passing them to `loss_fn`.

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    thresholds: Iterable of floats in the range `(0, 1)`. For binary
      classification metrics such as precision and recall, an eval metric is
      generated for each threshold value. This threshold is applied to the
      logistic values to determine the binary classification (i.e., above the
      threshold is `true`, below is `false`).
    label_vocabulary: A list or tuple of strings representing possible label
      values. If it is not given, labels must be float with values within
      [0, 1]. If given, labels must be string type and have any value in
      `label_vocabulary`. Note that errors will be raised if `label_vocabulary`
      is not provided but labels are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for binary classification.

  Raises:
    ValueError: If `thresholds` contains a value outside of `(0, 1)`.
    ValueError: If `loss_reduction` is invalid.
  """
  return head_lib._binary_logistic_head_with_sigmoid_cross_entropy_loss(  # pylint:disable=protected-access
      weight_column=weight_column,
      thresholds=thresholds,
      label_vocabulary=label_vocabulary,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def regression_head(weight_column=None,
                    label_dimension=1,
                    loss_reduction=losses.Reduction.SUM,
                    loss_fn=None,
                    name=None):
  """Creates a `_Head` for regression using the `mean_squared_error` loss.

  The loss is the weighted sum over all input dimensions. Namely, if the input
  labels have shape `[batch_size, label_dimension]`, the loss is the weighted
  sum over both `batch_size` and `label_dimension`.

  The head expects `logits` with shape `[D0, D1, ... DN, label_dimension]`.
  In many applications, the shape is `[batch_size, label_dimension]`.

  The `labels` shape must match `logits`, namely
  `[D0, D1, ... DN, label_dimension]`. If `label_dimension=1`, shape
  `[D0, D1, ... DN]` is also supported.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, `[D0, D1, ... DN, 1]` or
  `[D0, D1, ... DN, label_dimension]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, label_dimension]`.

  Args:
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: Number of regression labels per example. This is the size
      of the last dimension of the labels `Tensor` (typically, this has shape
      `[batch_size, label_dimension]`).
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for linear regression.

  Raises:
    ValueError: If `label_dimension` or `loss_reduction` is invalid.
  """
  return head_lib._regression_head_with_mean_squared_error_loss(  # pylint:disable=protected-access
      weight_column=weight_column,
      label_dimension=label_dimension,
      loss_reduction=loss_reduction,
      loss_fn=loss_fn,
      name=name)


def multi_label_head(n_classes,
                     weight_column=None,
                     thresholds=None,
                     label_vocabulary=None,
                     loss_reduction=losses.Reduction.SUM,
                     loss_fn=None,
                     name=None):
  """Creates a `_Head` for multi-label classification.

  Multi-label classification handles the case where each example may have zero
  or more associated labels, from a discrete set. This is distinct from
  `multi_class_head` which has exactly one label per example.

  Uses `sigmoid_cross_entropy` loss averaged over classes and weighted-summed
  over the batch. Namely, if the input logits have shape
  `[batch_size, n_classes]`, the loss is the average over `n_classes` and the
  weighted sum over `batch_size`.

  The head expects `logits` with shape `[D0, D1, ... DN, n_classes]`. In many
  applications, the shape is `[batch_size, n_classes]`.

  Labels can be:
  * A multi-hot tensor of shape `[D0, D1, ... DN, n_classes]`
  * An integer `SparseTensor` of class indices. The `dense_shape` must be
    `[D0, D1, ... DN, ?]` and the values within `[0, n_classes)`.
  * If `label_vocabulary` is given, a string `SparseTensor`. The `dense_shape`
    must be `[D0, D1, ... DN, ?]` and the values within `label_vocabulary`.

  If `weight_column` is specified, weights must be of shape
  `[D0, D1, ... DN]`, or `[D0, D1, ... DN, 1]`.

  Also supports custom `loss_fn`. `loss_fn` takes `(labels, logits)` or
  `(labels, logits, features)` as arguments and returns unreduced loss with
  shape `[D0, D1, ... DN, 1]`. `loss_fn` must support indicator `labels` with
  shape `[D0, D1, ... DN, n_classes]`. Namely, the head applies
  `label_vocabulary` to the input labels before passing them to `loss_fn`.

  Args:
    n_classes: Number of classes, must be greater than 1 (for 1 class, use
      `binary_classification_head`).
    weight_column: A string or a `_NumericColumn` created by
      `tf.feature_column.numeric_column` defining feature column representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example. Per-class weighting is
      not supported.
    thresholds: Iterable of floats in the range `(0, 1)`. Accuracy, precision
      and recall metrics are evaluated for each threshold value. The threshold
      is applied to the predicted probabilities, i.e. above the threshold is
      `true`, below is `false`.
    label_vocabulary: A list or tuple of strings representing possible label
      values; its length must equal `n_classes`. If it is not given, labels are
      expected to be already encoded as integers within [0, n_classes) or as a
      multi-hot Tensor. If given, labels must be a string `SparseTensor` with
      values in `label_vocabulary`. An error is raised if the vocabulary is not
      provided but labels are strings.
    loss_reduction: One of `tf.losses.Reduction` except `NONE`. Describes how to
      reduce training loss over batch. Defaults to `SUM`.
    loss_fn: Optional loss function.
    name: name of the head. If provided, summary and metrics keys will be
      suffixed by `"/" + name`. Also used as `name_scope` when creating ops.

  Returns:
    An instance of `_Head` for multi-label classification.

  Raises:
    ValueError: if `n_classes`, `thresholds`, `label_vocabulary`,
      `loss_reduction` or `loss_fn` is invalid.
  """
  thresholds = tuple(thresholds) if thresholds else tuple()
  if n_classes is None or n_classes < 2:
    raise ValueError(
        'n_classes must be > 1 for multi-label classification. '
        'Given: {}'.format(n_classes))
  for threshold in thresholds:
    if (threshold <= 0.0) or (threshold >= 1.0):
      raise ValueError(
          'thresholds must be in (0, 1) range. Given: {}'.format(threshold))
  if label_vocabulary is not None:
    if not isinstance(label_vocabulary, (list, tuple)):
      raise ValueError(
          'label_vocabulary must be a list or tuple. '
          'Given type: {}'.format(type(label_vocabulary)))
    # Each vocabulary entry maps to exactly one class id.
    if len(label_vocabulary) != n_classes:
      raise ValueError(
          'Length of label_vocabulary must be n_classes ({}). '
          'Given: {}'.format(n_classes, len(label_vocabulary)))
  if loss_fn:
    head_lib._validate_loss_fn_args(loss_fn)  # pylint:disable=protected-access
  if (loss_reduction not in losses.Reduction.all() or
      loss_reduction == losses.Reduction.NONE):
    raise ValueError('Invalid loss_reduction: {}'.format(loss_reduction))
  return _MultiLabelHead(
      n_classes=n_classes, weight_column=weight_column, thresholds=thresholds,
      label_vocabulary=label_vocabulary, loss_reduction=loss_reduction,
      loss_fn=loss_fn, name=name)


class _MultiLabelHead(head_lib._Head):  # pylint:disable=protected-access
  """`_Head` for multi-label classification.

  Uses sigmoid probabilities for predictions and, unless a custom `loss_fn` is
  given, sigmoid cross entropy averaged over the class dimension as the
  unreduced per-example loss.
  """

  def __init__(self,
               n_classes,
               weight_column=None,
               thresholds=None,
               label_vocabulary=None,
               loss_reduction=losses.Reduction.SUM,
               loss_fn=None,
               name=None):
    self._n_classes = n_classes
    self._weight_column = weight_column
    # Normalized to a tuple so `_eval_metric_ops` can always iterate it, even
    # when the head is constructed directly with the default `None`.
    self._thresholds = tuple(thresholds) if thresholds else tuple()
    self._label_vocabulary = label_vocabulary
    self._loss_reduction = loss_reduction
    self._loss_fn = loss_fn
    self._name = name

  @property
  def name(self):
    return self._name

  @property
  def logits_dimension(self):
    return self._n_classes

  def _process_labels(self, labels):
    """Converts input labels to a multi-hot indicator `Tensor`.

    Args:
      labels: A multi-hot dense `Tensor`, an integer `SparseTensor` of class
        ids, or (when `label_vocabulary` was given) a string `SparseTensor`.

    Returns:
      For `SparseTensor` input, an int64 indicator `Tensor` built with
      `sparse_to_indicator` over `n_classes`. For dense input, the labels after
      the range assertion applied by `head_lib._assert_range`.

    Raises:
      ValueError: If `labels` is `None`, or if `labels` is a string
        `SparseTensor` but the head has no `label_vocabulary`.
    """
    if labels is None:
      raise ValueError(
          'You must provide a labels Tensor. Given: None. '
          'Suggested troubleshooting steps: Check that your data contain '
          'your label feature. Check that your input_fn properly parses and '
          'returns labels.')
    if isinstance(labels, sparse_tensor.SparseTensor):
      if labels.dtype == dtypes.string:
        # Without a vocabulary there is no way to map strings to class ids;
        # fail early with a clear message instead of `tuple(None)` TypeError.
        if self._label_vocabulary is None:
          raise ValueError(
              'label_vocabulary must be provided when labels are a string '
              'SparseTensor.')
        label_ids_values = lookup_ops.index_table_from_tensor(
            vocabulary_list=tuple(self._label_vocabulary),
            name='class_id_lookup').lookup(labels.values)
        label_ids = sparse_tensor.SparseTensor(
            indices=labels.indices,
            values=label_ids_values,
            dense_shape=labels.dense_shape)
        return math_ops.to_int64(
            sparse_ops.sparse_to_indicator(label_ids, self._n_classes))
      else:
        err_msg = (
            r'labels must be an integer SparseTensor with values in '
            r'[0, {})'.format(self._n_classes))
        assert_int = check_ops.assert_integer(
            labels.values, message=err_msg)
        assert_less = check_ops.assert_less(
            labels.values,
            ops.convert_to_tensor(self._n_classes, dtype=labels.dtype),
            message=err_msg)
        assert_greater = check_ops.assert_non_negative(
            labels.values, message=err_msg)
        # Runtime checks must run before the ids are turned into indicators.
        with ops.control_dependencies(
            [assert_int, assert_less, assert_greater]):
          return math_ops.to_int64(
              sparse_ops.sparse_to_indicator(labels, self._n_classes))
    err_msg = (
        r'labels must be an integer indicator Tensor with values in [0, 1]')
    return head_lib._assert_range(labels, 2, message=err_msg)  # pylint:disable=protected-access

  def create_loss(self, features, mode, logits, labels):
    """See `Head`."""
    del mode  # Unused for this head.
    logits = ops.convert_to_tensor(logits)
    processed_labels = self._process_labels(labels)
    processed_labels = head_lib._check_dense_labels_match_logits_and_reshape(  # pylint:disable=protected-access
        labels=processed_labels, logits=logits,
        expected_labels_dimension=self.logits_dimension)
    if self._loss_fn:
      unweighted_loss = head_lib._call_loss_fn(  # pylint:disable=protected-access
          loss_fn=self._loss_fn, labels=processed_labels, logits=logits,
          features=features, expected_loss_dim=1)
    else:
      unweighted_loss = losses.sigmoid_cross_entropy(
          multi_class_labels=processed_labels, logits=logits,
          reduction=losses.Reduction.NONE)
      # Averages loss over classes.
      unweighted_loss = math_ops.reduce_mean(
          unweighted_loss, axis=-1, keep_dims=True)
    weights = head_lib._get_weights_and_check_match_logits(  # pylint:disable=protected-access
        features=features, weight_column=self._weight_column, logits=logits)
    training_loss = losses.compute_weighted_loss(
        unweighted_loss, weights=weights, reduction=self._loss_reduction)
    return head_lib.LossSpec(
        training_loss=training_loss,
        unreduced_loss=unweighted_loss,
        weights=weights,
        processed_labels=processed_labels)

  def create_estimator_spec(
      self, features, mode, logits, labels=None, train_op_fn=None,
      regularization_losses=None):
    """Returns an `EstimatorSpec`.

    Args:
      features: Input `dict` of `Tensor` or `SparseTensor` objects.
      mode: Estimator's `ModeKeys`.
      logits: logits `Tensor` with shape `[D0, D1, ... DN, n_classes]`.
        For many applications, the shape is `[batch_size, n_classes]`.
      labels: Labels with shape matching `logits`. Can be multi-hot `Tensor`
        with shape `[D0, D1, ... DN, n_classes]` or `SparseTensor` with
        `dense_shape` `[D0, D1, ... DN, ?]`. `labels` is required argument when
        `mode` equals `TRAIN` or `EVAL`.
      train_op_fn: Function that takes a scalar loss `Tensor` and returns
        `train_op`. Required in TRAIN mode.
      regularization_losses: A list of additional scalar losses to be added to
        the training loss, such as regularization losses. These losses are
        usually expressed as a batch average, so for best results users need to
        set `loss_reduction=SUM_OVER_BATCH_SIZE` or
        `loss_reduction=SUM_OVER_NONZERO_WEIGHTS` when creating the head to
        avoid scaling errors.
    Returns:
      `EstimatorSpec`.
    Raises:
      ValueError: If `train_op_fn` is `None` in TRAIN mode.
    """
    with ops.name_scope(self._name, 'head'):
      logits = head_lib._check_logits_final_dim(logits, self.logits_dimension)  # pylint:disable=protected-access

      # Predict.
      pred_keys = prediction_keys.PredictionKeys
      with ops.name_scope(None, 'predictions', (logits,)):
        probabilities = math_ops.sigmoid(logits, name=pred_keys.PROBABILITIES)
        predictions = {
            pred_keys.LOGITS: logits,
            pred_keys.PROBABILITIES: probabilities,
        }
      if mode == model_fn.ModeKeys.PREDICT:
        classifier_output = head_lib._classification_output(  # pylint:disable=protected-access
            scores=probabilities, n_classes=self._n_classes,
            label_vocabulary=self._label_vocabulary)
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.PREDICT,
            predictions=predictions,
            export_outputs={
                _DEFAULT_SERVING_KEY: classifier_output,
                head_lib._CLASSIFY_SERVING_KEY: classifier_output,  # pylint:disable=protected-access
                head_lib._PREDICT_SERVING_KEY: (  # pylint:disable=protected-access
                    export_output.PredictOutput(predictions))
            })

      (training_loss, unreduced_loss, weights,
       processed_labels) = self.create_loss(
           features=features, mode=mode, logits=logits, labels=labels)
      if regularization_losses:
        regularization_loss = math_ops.add_n(regularization_losses)
        regularized_training_loss = math_ops.add_n(
            [training_loss, regularization_loss])
      else:
        regularization_loss = None
        regularized_training_loss = training_loss

      # Eval.
      if mode == model_fn.ModeKeys.EVAL:
        return model_fn.EstimatorSpec(
            mode=model_fn.ModeKeys.EVAL,
            predictions=predictions,
            loss=regularized_training_loss,
            eval_metric_ops=self._eval_metric_ops(
                labels=processed_labels,
                probabilities=probabilities,
                weights=weights,
                unreduced_loss=unreduced_loss,
                regularization_loss=regularization_loss))

      # Train.
      if train_op_fn is None:
        raise ValueError('train_op_fn can not be None.')
      # Only summarize mean_loss for SUM reduction to preserve backwards
      # compatibility. Otherwise skip it to avoid unnecessary computation.
      if self._loss_reduction == losses.Reduction.SUM:
        example_weight_sum = math_ops.reduce_sum(
            weights * array_ops.ones_like(unreduced_loss))
        mean_loss = training_loss / example_weight_sum
      else:
        mean_loss = None
    with ops.name_scope(''):
      keys = metric_keys.MetricKeys
      summary.scalar(
          head_lib._summary_key(self._name, keys.LOSS),  # pylint:disable=protected-access
          regularized_training_loss)
      if mean_loss is not None:
        summary.scalar(
            head_lib._summary_key(self._name, keys.LOSS_MEAN),  # pylint:disable=protected-access
            mean_loss)
      if regularization_loss is not None:
        summary.scalar(
            head_lib._summary_key(self._name, keys.LOSS_REGULARIZATION),  # pylint:disable=protected-access
            regularization_loss)
    return model_fn.EstimatorSpec(
        mode=model_fn.ModeKeys.TRAIN,
        predictions=predictions,
        loss=regularized_training_loss,
        train_op=train_op_fn(regularized_training_loss))

  def _eval_metric_ops(
      self, labels, probabilities, weights, unreduced_loss,
      regularization_loss):
    """Returns a dict of metrics for eval_metric_ops."""
    with ops.name_scope(
        None, 'metrics',
        [labels, probabilities, weights, unreduced_loss, regularization_loss]):
      keys = metric_keys.MetricKeys
      metric_ops = {
          # Estimator already adds a metric for loss.
          head_lib._summary_key(self._name, keys.LOSS_MEAN):  # pylint:disable=protected-access
              metrics_lib.mean(
                  values=unreduced_loss,
                  weights=weights,
                  name=keys.LOSS_MEAN),
          head_lib._summary_key(self._name, keys.AUC):  # pylint:disable=protected-access
              metrics_lib.auc(labels=labels, predictions=probabilities,
                              weights=weights, name=keys.AUC),
          head_lib._summary_key(self._name, keys.AUC_PR):  # pylint:disable=protected-access
              metrics_lib.auc(labels=labels, predictions=probabilities,
                              weights=weights, curve='PR',
                              name=keys.AUC_PR),
      }
      if regularization_loss is not None:
        loss_regularization_key = head_lib._summary_key(  # pylint:disable=protected-access
            self._name, keys.LOSS_REGULARIZATION)
        metric_ops[loss_regularization_key] = (
            metrics_lib.mean(
                values=regularization_loss,
                name=keys.LOSS_REGULARIZATION))
      for threshold in self._thresholds:
        accuracy_key = keys.ACCURACY_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, accuracy_key)] = (  # pylint:disable=protected-access
            head_lib._accuracy_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=accuracy_key))
        # Precision for positive examples.
        precision_key = keys.PRECISION_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, precision_key)] = (  # pylint:disable=protected-access
            head_lib._precision_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=precision_key))
        # Recall for positive examples.
        recall_key = keys.RECALL_AT_THRESHOLD % threshold
        metric_ops[head_lib._summary_key(self._name, recall_key)] = (  # pylint:disable=protected-access
            head_lib._recall_at_threshold(  # pylint:disable=protected-access
                labels=labels,
                predictions=probabilities,
                weights=weights,
                threshold=threshold,
                name=recall_key))
    return metric_ops