# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""TargetColumn abstracts a single head in the model.
"""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import six

from tensorflow.contrib.framework import deprecated
from tensorflow.contrib.losses.python.losses import loss_ops
from tensorflow.contrib.metrics.python.ops import metric_ops
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn


@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date."
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def regression_target(label_name=None,
                      weight_column_name=None,
                      label_dimension=1):
  """Creates a _TargetColumn for linear regression.

  The target column uses the element-wise squared error as its loss.

  Args:
    label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.
    label_dimension: dimension of the target for multilabels.

  Returns:
    An instance of _TargetColumn
  """
  return _RegressionTargetColumn(
      loss_fn=_mean_squared_loss,
      label_name=label_name,
      weight_column_name=weight_column_name,
      label_dimension=label_dimension)


# TODO(zakaria): Add logistic_regression_target


@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date."
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def multi_class_target(n_classes, label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for multi class single label classification.

  The target column uses softmax cross entropy loss, except when
  `n_classes == 2`, in which case a sigmoid cross entropy loss over a single
  logit column is used.

  Args:
    n_classes: Integer, number of classes, must be >= 2
    label_name: String, name of the key in label dict. Can be null if label
        is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _MultiClassTargetColumn.

  Raises:
    ValueError: if n_classes is < 2
  """
  if n_classes < 2:
    raise ValueError("n_classes must be > 1 for classification.")
  if n_classes == 2:
    loss_fn = _log_loss_with_two_classes
  else:
    loss_fn = _softmax_cross_entropy_loss
  return _MultiClassTargetColumn(
      loss_fn=loss_fn,
      n_classes=n_classes,
      label_name=label_name,
      weight_column_name=weight_column_name)


@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date."
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def binary_svm_target(label_name=None, weight_column_name=None):
  """Creates a _TargetColumn for binary classification with SVMs.

  The target column uses binary hinge loss.

  Args:
    label_name: String, name of the key in label dict. Can be null if label
      is a tensor (single headed models).
    weight_column_name: A string defining feature column name representing
      weights. It is used to down weight or boost examples during training. It
      will be multiplied by the loss of the example.

  Returns:
    An instance of _TargetColumn.

  """
  return _BinarySvmTargetColumn(
      label_name=label_name, weight_column_name=weight_column_name)


@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date."
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
class ProblemType(object):
  """Integer constants identifying the kind of problem a target column solves."""
  UNSPECIFIED = 0
  CLASSIFICATION = 1
  LINEAR_REGRESSION = 2
  LOGISTIC_REGRESSION = 3


class _TargetColumn(object):
  """_TargetColumn is the abstraction for a single head in a model.

    Args:
      loss_fn: a function that returns the loss tensor.
      num_label_columns: Integer, number of label columns.
      label_name: String, name of the key in label dict. Can be null if label
          is a tensor (single headed models).
      weight_column_name: A string defining feature column name representing
        weights. It is used to down weight or boost examples during training. It
        will be multiplied by the loss of the example.
      problem_type: One of the `ProblemType` constants; stored as-is and
        exposed through the `problem_type` property.

    Raises:
      ValueError: if loss_fn or num_label_columns are missing.
  """

  def __init__(self, loss_fn, num_label_columns, label_name, weight_column_name,
               problem_type):
    if not loss_fn:
      raise ValueError("loss_fn must be provided")
    if num_label_columns is None:  # n_classes can be 0
      raise ValueError("num_label_columns must be provided")

    self._loss_fn = loss_fn
    self._num_label_columns = num_label_columns
    self._label_name = label_name
    self._weight_column_name = weight_column_name
    self._problem_type = problem_type

  def logits_to_predictions(self, logits, proba=False):
    """Converts logits to predictions. Abstract, subclasses must implement."""
    raise NotImplementedError()

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns eval op. Abstract, subclasses must implement."""
    raise NotImplementedError

  @property
  def label_name(self):
    return self._label_name

  @property
  def weight_column_name(self):
    return self._weight_column_name

  @property
  def num_label_columns(self):
    return self._num_label_columns

  def get_weight_tensor(self, features):
    """Returns the example weights as a flat float tensor, or None.

    Args:
      features: features dict; indexed by `weight_column_name` when one is set.

    Returns:
      None when no weight column name was configured; otherwise the weight
      feature converted to float and reshaped to rank 1.
    """
    if not self._weight_column_name:
      return None
    else:
      return array_ops.reshape(
          math_ops.to_float(features[self._weight_column_name]), shape=(-1,))

  @property
  def problem_type(self):
    return self._problem_type

  def _select_target(self, target):
    """Picks this head's target out of a label dict; passes tensors through.

    In the multi-head case the labels arrive as a dict keyed by label name, so
    the entry for this head is looked up with `label_name`.
    """
    if isinstance(target, dict):
      return target[self._label_name]
    return target

  def _weighted_loss(self, loss, weight_tensor):
    """Returns cumulative weighted loss."""
    # Both operands are flattened to rank 1 before the element-wise product.
    unweighted_loss = array_ops.reshape(loss, shape=(-1,))
    weighted_loss = math_ops.multiply(unweighted_loss,
                                      array_ops.reshape(
                                          weight_tensor, shape=(-1,)))
    return weighted_loss

  def training_loss(self, logits, target, features, name="training_loss"):
    """Returns training loss tensor for this head.

    Training loss is different from the loss reported on the tensorboard as we
    should respect the example weights when computing the gradient.

      L = sum_{i} w_{i} * l_{i} / B

    where B is the number of examples in the batch, l_{i}, w_{i} are individual
    losses, and example weight.

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor.
      features: features dict.
      name: Op name.

    Returns:
      Loss tensor.
    """
    target = self._select_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name=name)
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.reduce_mean(loss_weighted, name=name)

  def loss(self, logits, target, features):
    """Returns loss tensor for this head.

    The loss returned is the weighted average.

      L = sum_{i} w_{i} * l_{i} / sum_{i} w_{i}

    Args:
      logits: logits, a float tensor.
      target: either a tensor for labels or in multihead case, a dict of string
        to target tensor.
      features: features dict.

    Returns:
      Loss tensor.
    """
    target = self._select_target(target)
    loss_unweighted = self._loss_fn(logits, target)

    weight_tensor = self.get_weight_tensor(features)
    if weight_tensor is None:
      return math_ops.reduce_mean(loss_unweighted, name="loss")
    loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    return math_ops.div(math_ops.reduce_sum(loss_weighted),
                        math_ops.to_float(math_ops.reduce_sum(weight_tensor)),
                        name="loss")


class _RegressionTargetColumn(_TargetColumn):
  """_TargetColumn for regression."""

  def __init__(self, loss_fn, label_name, weight_column_name, label_dimension):
    super(_RegressionTargetColumn, self).__init__(
        loss_fn=loss_fn,
        num_label_columns=label_dimension,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.LINEAR_REGRESSION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns the logits as predictions; `proba` is ignored.

    With a single label column, dimension 1 of `logits` is squeezed away.
    """
    if self.num_label_columns == 1:
      return array_ops.squeeze(logits, squeeze_dims=[1])
    return logits

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict with the streaming mean loss plus any requested metrics.

    Args:
      features: features dict, used to look up the example weights.
      logits: logits tensor.
      labels: labels passed to the loss and to each metric.
      metrics: optional dict of name to metric function.

    Returns:
      Dict of metric name to the value returned by the metric function.
    """
    loss = self.loss(logits, labels, features)
    result = {"loss": metric_ops.streaming_mean(loss)}
    if metrics:
      predictions = self.logits_to_predictions(logits, proba=False)
      result.update(
          _run_metrics(predictions, labels, metrics,
                       self.get_weight_tensor(features)))
    return result


class _MultiClassTargetColumn(_TargetColumn):
  """_TargetColumn for classification."""

  # TODO(zakaria): support multilabel.
  def __init__(self, loss_fn, n_classes, label_name, weight_column_name):
    if n_classes < 2:
      raise ValueError("n_classes must be >= 2")
    # Binary classification uses a single logit column.
    super(_MultiClassTargetColumn, self).__init__(
        loss_fn=loss_fn,
        num_label_columns=1 if n_classes == 2 else n_classes,
        label_name=label_name,
        weight_column_name=weight_column_name,
        problem_type=ProblemType.CLASSIFICATION)

  def logits_to_predictions(self, logits, proba=False):
    """Returns softmax probabilities if `proba`, else argmax class ids."""
    if self.num_label_columns == 1:
      # Prepend a column of zeros so the single binary logit becomes a
      # two-column input for softmax / argmax.
      logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)

    if proba:
      return nn.softmax(logits)
    else:
      return math_ops.argmax(logits, 1)

  def _default_eval_metrics(self):
    """Returns the default binary metrics at threshold 0.5, or {} otherwise."""
    if self._num_label_columns == 1:
      return get_default_binary_metrics_for_eval(thresholds=[.5])
    return {}

  def get_eval_ops(self, features, logits, labels, metrics=None):
    """Returns a dict of evaluation metrics for this head.

    Args:
      features: features dict, used to look up the example weights.
      logits: logits tensor.
      labels: labels passed to the loss and to each metric.
      metrics: optional dict of metric name to metric function. A name is
        either a string (metric is fed class predictions) or a 2-tuple
        `(name, "classes" | "probabilities")` selecting which predictions the
        metric receives. If None, streaming accuracy over classes is used.

    Returns:
      Dict of metric name to the value returned by the metric function; always
      contains "loss".

    Raises:
      ValueError: if a metric name is neither a string nor a valid 2-tuple.
    """
    loss = self.loss(logits, labels, features)
    result = {"loss": metric_ops.streaming_mean(loss)}

    # Adds default metrics.
    if metrics is None:
      # TODO(b/29366811): This currently results in both an "accuracy" and an
      # "accuracy/threshold_0.500000_mean" metric for binary classification.
      metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy}

    predictions = math_ops.sigmoid(logits)
    labels_float = math_ops.to_float(labels)

    default_metrics = self._default_eval_metrics()
    for metric_name, metric_op in default_metrics.items():
      result[metric_name] = metric_op(predictions, labels_float)

    # Split the requested metrics by the kind of prediction they consume.
    class_metrics = {}
    proba_metrics = {}
    for name, metric_op in six.iteritems(metrics):
      if isinstance(name, tuple):
        if len(name) != 2:
          raise ValueError("Ignoring metric {}. It returned a tuple with "
                           "len {}, expected 2.".format(name, len(name)))
        if name[1] not in ["classes", "probabilities"]:
          raise ValueError("Ignoring metric {}. The 2nd element of its "
                           "name should be either 'classes' or "
                           "'probabilities'.".format(name))
        if name[1] == "classes":
          class_metrics[name[0]] = metric_op
        else:
          proba_metrics[name[0]] = metric_op
      elif isinstance(name, six.string_types):
        # six.string_types also accepts unicode names under Python 2.
        class_metrics[name] = metric_op
      else:
        raise ValueError("Ignoring metric {}. Its name is not in the correct "
                         "form.".format(name))
    if class_metrics:
      class_predictions = self.logits_to_predictions(logits, proba=False)
      result.update(
          _run_metrics(class_predictions, labels, class_metrics,
                       self.get_weight_tensor(features)))
    if proba_metrics:
      predictions = self.logits_to_predictions(logits, proba=True)
      result.update(
          _run_metrics(predictions, labels, proba_metrics,
                       self.get_weight_tensor(features)))
    return result


class _BinarySvmTargetColumn(_MultiClassTargetColumn):
  """_TargetColumn for binary classification using SVMs."""

  def __init__(self, label_name, weight_column_name):

    def loss_fn(logits, target):
      """Hinge loss after reshaping `target` to [batch_size, 1]."""
      check_shape_op = control_flow_ops.Assert(
          math_ops.less_equal(array_ops.rank(target), 2),
          ["target's shape should be either [batch_size, 1] or [batch_size]"])
      # The reshape only runs once the rank assertion has been evaluated.
      with ops.control_dependencies([check_shape_op]):
        target = array_ops.reshape(
            target, shape=[array_ops.shape(target)[0], 1])
      return loss_ops.hinge_loss(logits, target)

    super(_BinarySvmTargetColumn, self).__init__(
        loss_fn=loss_fn,
        n_classes=2,
        label_name=label_name,
        weight_column_name=weight_column_name)

  def logits_to_predictions(self, logits, proba=False):
    """Returns argmax class ids; probabilities are not supported.

    Raises:
      ValueError: if `proba` is true.
    """
    if proba:
      raise ValueError(
          "logits to probabilities is not supported for _BinarySvmTargetColumn")

    logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)
    return math_ops.argmax(logits, 1)


# TODO(zakaria): use contrib losses.
def _mean_squared_loss(logits, target):
  """Returns the element-wise squared difference of `logits` and `target`."""
  # To prevent broadcasting inside "-".
  if len(target.get_shape()) == 1:
    target = array_ops.expand_dims(target, dim=[1])

  logits.get_shape().assert_is_compatible_with(target.get_shape())
  return math_ops.square(logits - math_ops.to_float(target))


def _log_loss_with_two_classes(logits, target):
  """Returns the per-example sigmoid cross entropy loss."""
  # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
  if len(target.get_shape()) == 1:
    target = array_ops.expand_dims(target, dim=[1])
  loss_vec = nn.sigmoid_cross_entropy_with_logits(
      labels=math_ops.to_float(target), logits=logits)
  return loss_vec


def _softmax_cross_entropy_loss(logits, target):
  """Returns the per-example sparse softmax cross entropy loss.

  Raises:
    ValueError: if `target` does not have an integer dtype.
  """
  # Check that we got integer for classification.
  if not target.dtype.is_integer:
    raise ValueError("Target's dtype should be integer "
                     "Instead got %s." % target.dtype)
  # sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
  if len(target.get_shape()) == 2:
    target = array_ops.squeeze(target, squeeze_dims=[1])
  loss_vec = nn.sparse_softmax_cross_entropy_with_logits(
      labels=target, logits=logits)
  return loss_vec


def _run_metrics(predictions, labels, metrics, weights):
  """Calls each metric function and collects the results by name.

  Args:
    predictions: predictions tensor passed to every metric.
    labels: labels tensor; cast to the dtype of `predictions` first.
    metrics: dict of name to metric function, or None.
    weights: optional weights; forwarded as `weights=` only when not None.

  Returns:
    Dict of metric name to the value returned by the metric function.
  """
  result = {}
  labels = math_ops.cast(labels, predictions.dtype)
  for name, metric in six.iteritems(metrics or {}):
    if weights is not None:
      result[name] = metric(predictions, labels, weights=weights)
    else:
      result[name] = metric(predictions, labels)

  return result


@deprecated(
    "2016-11-12", "This file will be removed after the deprecation date."
    "Please switch to "
    "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
def get_default_binary_metrics_for_eval(thresholds):
  """Returns a dictionary of basic metrics for logistic regression.

  Args:
    thresholds: List of floating point thresholds to use for accuracy,
      precision, and recall metrics. If None, defaults to [0.5].

  Returns:
    Dictionary mapping metrics string names to metrics functions.
  """
  # Honor the documented default instead of failing on iteration over None.
  if thresholds is None:
    thresholds = [0.5]

  metrics = {}
  metrics[_MetricKeys.PREDICTION_MEAN] = _predictions_streaming_mean
  metrics[_MetricKeys.TARGET_MEAN] = _labels_streaming_mean
  # Also include the streaming mean of the label as an accuracy baseline, as
  # a reminder to users.
  metrics[_MetricKeys.ACCURACY_BASELINE] = _labels_streaming_mean

  metrics[_MetricKeys.AUC] = _streaming_auc

  for threshold in thresholds:
    metrics[_MetricKeys.ACCURACY_MEAN %
            threshold] = _accuracy_at_threshold(threshold)
    # Precision for positive examples.
    metrics[_MetricKeys.PRECISION_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_precision_at_thresholds, threshold)
    # Recall for positive examples.
    metrics[_MetricKeys.RECALL_MEAN % threshold] = _streaming_at_threshold(
        metric_ops.streaming_recall_at_thresholds, threshold)

  return metrics


def _float_weights_or_none(weights):
  """Returns `weights` converted to float, or None if `weights` is None."""
  if weights is None:
    return None
  return math_ops.to_float(weights)


def _labels_streaming_mean(unused_predictions, labels, weights=None):
  """Streaming mean of the labels; predictions are ignored."""
  return metric_ops.streaming_mean(labels, weights=weights)


def _predictions_streaming_mean(predictions, unused_labels, weights=None):
  """Streaming mean of the predictions; labels are ignored."""
  return metric_ops.streaming_mean(predictions, weights=weights)


def _streaming_auc(predictions, labels, weights=None):
  """Streaming AUC with the weights converted to float when given."""
  return metric_ops.streaming_auc(
      predictions, labels, weights=_float_weights_or_none(weights))


def _accuracy_at_threshold(threshold):
  """Returns a metric function computing accuracy at the given threshold."""

  def _accuracy_metric(predictions, labels, weights=None):
    # Predictions >= threshold become 1.0, all others 0.0.
    threshold_predictions = math_ops.to_float(
        math_ops.greater_equal(predictions, threshold))
    return metric_ops.streaming_accuracy(
        predictions=threshold_predictions, labels=labels, weights=weights)

  return _accuracy_metric


def _streaming_at_threshold(streaming_metrics_fn, threshold):
  """Wraps an `*_at_thresholds` metric so it evaluates a single threshold.

  Args:
    streaming_metrics_fn: metric function accepting `thresholds` and `weights`
      keyword arguments and returning a `(value, update_op)` pair.
    threshold: the single threshold to evaluate.

  Returns:
    A metric function of `(predictions, labels, weights=None)` returning the
    squeezed value tensor and the update op.
  """

  def _streaming_metrics(predictions, labels, weights=None):
    precision_tensor, update_op = streaming_metrics_fn(
        predictions,
        labels=labels,
        thresholds=[threshold],
        weights=_float_weights_or_none(weights))
    return array_ops.squeeze(precision_tensor), update_op

  return _streaming_metrics


class _MetricKeys(object):
  """Names of the default binary metrics; `%f` slots take the threshold."""
  AUC = "auc"
  PREDICTION_MEAN = "labels/prediction_mean"
  TARGET_MEAN = "labels/actual_target_mean"
  ACCURACY_BASELINE = "accuracy/baseline_target_mean"
  ACCURACY_MEAN = "accuracy/threshold_%f_mean"
  PRECISION_MEAN = "precision/positive_threshold_%f_mean"
  RECALL_MEAN = "recall/positive_threshold_%f_mean"