Home | History | Annotate | Download | only in layers
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
"""TargetColumn abstracts a single head in the model.
"""
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import six
     22 
     23 from tensorflow.contrib.framework import deprecated
     24 from tensorflow.contrib.losses.python.losses import loss_ops
     25 from tensorflow.contrib.metrics.python.ops import metric_ops
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.ops import control_flow_ops
     29 from tensorflow.python.ops import math_ops
     30 from tensorflow.python.ops import nn
     31 
     32 
     33 @deprecated(
     34     "2016-11-12", "This file will be removed after the deprecation date."
     35     "Please switch to "
     36     "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
     37 def regression_target(label_name=None,
     38                       weight_column_name=None,
     39                       label_dimension=1):
     40   """Creates a _TargetColumn for linear regression.
     41 
     42   Args:
     43     label_name: String, name of the key in label dict. Can be null if label
     44         is a tensor (single headed models).
     45     weight_column_name: A string defining feature column name representing
     46       weights. It is used to down weight or boost examples during training. It
     47       will be multiplied by the loss of the example.
     48     label_dimension: dimension of the target for multilabels.
     49 
     50   Returns:
     51     An instance of _TargetColumn
     52   """
     53   return _RegressionTargetColumn(
     54       loss_fn=_mean_squared_loss,
     55       label_name=label_name,
     56       weight_column_name=weight_column_name,
     57       label_dimension=label_dimension)
     58 
     59 
     60 # TODO(zakaria): Add logistic_regression_target
     61 
     62 
     63 @deprecated(
     64     "2016-11-12", "This file will be removed after the deprecation date."
     65     "Please switch to "
     66     "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
     67 def multi_class_target(n_classes, label_name=None, weight_column_name=None):
     68   """Creates a _TargetColumn for multi class single label classification.
     69 
     70   The target column uses softmax cross entropy loss.
     71 
     72   Args:
     73     n_classes: Integer, number of classes, must be >= 2
     74     label_name: String, name of the key in label dict. Can be null if label
     75         is a tensor (single headed models).
     76     weight_column_name: A string defining feature column name representing
     77       weights. It is used to down weight or boost examples during training. It
     78       will be multiplied by the loss of the example.
     79 
     80   Returns:
     81     An instance of _MultiClassTargetColumn.
     82 
     83   Raises:
     84     ValueError: if n_classes is < 2
     85   """
     86   if n_classes < 2:
     87     raise ValueError("n_classes must be > 1 for classification.")
     88   if n_classes == 2:
     89     loss_fn = _log_loss_with_two_classes
     90   else:
     91     loss_fn = _softmax_cross_entropy_loss
     92   return _MultiClassTargetColumn(
     93       loss_fn=loss_fn,
     94       n_classes=n_classes,
     95       label_name=label_name,
     96       weight_column_name=weight_column_name)
     97 
     98 
     99 @deprecated(
    100     "2016-11-12", "This file will be removed after the deprecation date."
    101     "Please switch to "
    102     "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
    103 def binary_svm_target(label_name=None, weight_column_name=None):
    104   """Creates a _TargetColumn for binary classification with SVMs.
    105 
    106   The target column uses binary hinge loss.
    107 
    108   Args:
    109     label_name: String, name of the key in label dict. Can be null if label
    110       is a tensor (single headed models).
    111     weight_column_name: A string defining feature column name representing
    112       weights. It is used to down weight or boost examples during training. It
    113       will be multiplied by the loss of the example.
    114 
    115   Returns:
    116     An instance of _TargetColumn.
    117 
    118   """
    119   return _BinarySvmTargetColumn(
    120       label_name=label_name, weight_column_name=weight_column_name)
    121 
    122 
    123 @deprecated(
    124     "2016-11-12", "This file will be removed after the deprecation date."
    125     "Please switch to "
    126     "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
    127 class ProblemType(object):
    128   UNSPECIFIED = 0
    129   CLASSIFICATION = 1
    130   LINEAR_REGRESSION = 2
    131   LOGISTIC_REGRESSION = 3
    132 
    133 
    134 class _TargetColumn(object):
    135   """_TargetColumn is the abstraction for a single head in a model.
    136 
    137     Args:
    138       loss_fn: a function that returns the loss tensor.
    139       num_label_columns: Integer, number of label columns.
    140       label_name: String, name of the key in label dict. Can be null if label
    141           is a tensor (single headed models).
    142       weight_column_name: A string defining feature column name representing
    143         weights. It is used to down weight or boost examples during training. It
    144         will be multiplied by the loss of the example.
    145 
    146     Raises:
    147       ValueError: if loss_fn or n_classes are missing.
    148   """
    149 
    150   def __init__(self, loss_fn, num_label_columns, label_name, weight_column_name,
    151                problem_type):
    152     if not loss_fn:
    153       raise ValueError("loss_fn must be provided")
    154     if num_label_columns is None:  # n_classes can be 0
    155       raise ValueError("num_label_columns must be provided")
    156 
    157     self._loss_fn = loss_fn
    158     self._num_label_columns = num_label_columns
    159     self._label_name = label_name
    160     self._weight_column_name = weight_column_name
    161     self._problem_type = problem_type
    162 
    163   def logits_to_predictions(self, logits, proba=False):
    164     # Abstrat, Subclasses must implement.
    165     raise NotImplementedError()
    166 
    167   def get_eval_ops(self, features, logits, labels, metrics=None):
    168     """Returns eval op."""
    169     raise NotImplementedError
    170 
    171   @property
    172   def label_name(self):
    173     return self._label_name
    174 
    175   @property
    176   def weight_column_name(self):
    177     return self._weight_column_name
    178 
    179   @property
    180   def num_label_columns(self):
    181     return self._num_label_columns
    182 
    183   def get_weight_tensor(self, features):
    184     if not self._weight_column_name:
    185       return None
    186     else:
    187       return array_ops.reshape(
    188           math_ops.to_float(features[self._weight_column_name]), shape=(-1,))
    189 
    190   @property
    191   def problem_type(self):
    192     return self._problem_type
    193 
    194   def _weighted_loss(self, loss, weight_tensor):
    195     """Returns cumulative weighted loss."""
    196     unweighted_loss = array_ops.reshape(loss, shape=(-1,))
    197     weighted_loss = math_ops.multiply(unweighted_loss,
    198                                       array_ops.reshape(
    199                                           weight_tensor, shape=(-1,)))
    200     return weighted_loss
    201 
    202   def training_loss(self, logits, target, features, name="training_loss"):
    203     """Returns training loss tensor for this head.
    204 
    205     Training loss is different from the loss reported on the tensorboard as we
    206     should respect the example weights when computing the gradient.
    207 
    208       L = sum_{i} w_{i} * l_{i} / B
    209 
    210     where B is the number of examples in the batch, l_{i}, w_{i} are individual
    211     losses, and example weight.
    212 
    213     Args:
    214       logits: logits, a float tensor.
    215       target: either a tensor for labels or in multihead case, a dict of string
    216         to target tensor.
    217       features: features dict.
    218       name: Op name.
    219 
    220     Returns:
    221       Loss tensor.
    222     """
    223     target = target[self.name] if isinstance(target, dict) else target
    224     loss_unweighted = self._loss_fn(logits, target)
    225 
    226     weight_tensor = self.get_weight_tensor(features)
    227     if weight_tensor is None:
    228       return math_ops.reduce_mean(loss_unweighted, name=name)
    229     loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    230     return math_ops.reduce_mean(loss_weighted, name=name)
    231 
    232   def loss(self, logits, target, features):
    233     """Returns loss tensor for this head.
    234 
    235     The loss returned is the weighted average.
    236 
    237       L = sum_{i} w_{i} * l_{i} / sum_{i} w_{i}
    238 
    239     Args:
    240       logits: logits, a float tensor.
    241       target: either a tensor for labels or in multihead case, a dict of string
    242         to target tensor.
    243       features: features dict.
    244 
    245     Returns:
    246       Loss tensor.
    247     """
    248     target = target[self.name] if isinstance(target, dict) else target
    249     loss_unweighted = self._loss_fn(logits, target)
    250 
    251     weight_tensor = self.get_weight_tensor(features)
    252     if weight_tensor is None:
    253       return math_ops.reduce_mean(loss_unweighted, name="loss")
    254     loss_weighted = self._weighted_loss(loss_unweighted, weight_tensor)
    255     return math_ops.div(math_ops.reduce_sum(loss_weighted),
    256                         math_ops.to_float(math_ops.reduce_sum(weight_tensor)),
    257                         name="loss")
    258 
    259 
    260 class _RegressionTargetColumn(_TargetColumn):
    261   """_TargetColumn for regression."""
    262 
    263   def __init__(self, loss_fn, label_name, weight_column_name, label_dimension):
    264     super(_RegressionTargetColumn, self).__init__(
    265         loss_fn=loss_fn,
    266         num_label_columns=label_dimension,
    267         label_name=label_name,
    268         weight_column_name=weight_column_name,
    269         problem_type=ProblemType.LINEAR_REGRESSION)
    270 
    271   def logits_to_predictions(self, logits, proba=False):
    272     if self.num_label_columns == 1:
    273       return array_ops.squeeze(logits, squeeze_dims=[1])
    274     return logits
    275 
    276   def get_eval_ops(self, features, logits, labels, metrics=None):
    277     loss = self.loss(logits, labels, features)
    278     result = {"loss": metric_ops.streaming_mean(loss)}
    279     if metrics:
    280       predictions = self.logits_to_predictions(logits, proba=False)
    281       result.update(
    282           _run_metrics(predictions, labels, metrics,
    283                        self.get_weight_tensor(features)))
    284     return result
    285 
    286 
    287 class _MultiClassTargetColumn(_TargetColumn):
    288   """_TargetColumn for classification."""
    289 
    290   # TODO(zakaria): support multilabel.
    291   def __init__(self, loss_fn, n_classes, label_name, weight_column_name):
    292     if n_classes < 2:
    293       raise ValueError("n_classes must be >= 2")
    294     super(_MultiClassTargetColumn, self).__init__(
    295         loss_fn=loss_fn,
    296         num_label_columns=1 if n_classes == 2 else n_classes,
    297         label_name=label_name,
    298         weight_column_name=weight_column_name,
    299         problem_type=ProblemType.CLASSIFICATION)
    300 
    301   def logits_to_predictions(self, logits, proba=False):
    302     if self.num_label_columns == 1:
    303       logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)
    304 
    305     if proba:
    306       return nn.softmax(logits)
    307     else:
    308       return math_ops.argmax(logits, 1)
    309 
    310   def _default_eval_metrics(self):
    311     if self._num_label_columns == 1:
    312       return get_default_binary_metrics_for_eval(thresholds=[.5])
    313     return {}
    314 
    315   def get_eval_ops(self, features, logits, labels, metrics=None):
    316     loss = self.loss(logits, labels, features)
    317     result = {"loss": metric_ops.streaming_mean(loss)}
    318 
    319     # Adds default metrics.
    320     if metrics is None:
    321       # TODO(b/29366811): This currently results in both an "accuracy" and an
    322       # "accuracy/threshold_0.500000_mean" metric for binary classification.
    323       metrics = {("accuracy", "classes"): metric_ops.streaming_accuracy}
    324 
    325     predictions = math_ops.sigmoid(logits)
    326     labels_float = math_ops.to_float(labels)
    327 
    328     default_metrics = self._default_eval_metrics()
    329     for metric_name, metric_op in default_metrics.items():
    330       result[metric_name] = metric_op(predictions, labels_float)
    331 
    332     class_metrics = {}
    333     proba_metrics = {}
    334     for name, metric_op in six.iteritems(metrics):
    335       if isinstance(name, tuple):
    336         if len(name) != 2:
    337           raise ValueError("Ignoring metric {}. It returned a tuple with "
    338                            "len {}, expected 2.".format(name, len(name)))
    339         else:
    340           if name[1] not in ["classes", "probabilities"]:
    341             raise ValueError("Ignoring metric {}. The 2nd element of its "
    342                              "name should be either 'classes' or "
    343                              "'probabilities'.".format(name))
    344           elif name[1] == "classes":
    345             class_metrics[name[0]] = metric_op
    346           else:
    347             proba_metrics[name[0]] = metric_op
    348       elif isinstance(name, str):
    349         class_metrics[name] = metric_op
    350       else:
    351         raise ValueError("Ignoring metric {}. Its name is not in the correct "
    352                          "form.".format(name))
    353     if class_metrics:
    354       class_predictions = self.logits_to_predictions(logits, proba=False)
    355       result.update(
    356           _run_metrics(class_predictions, labels, class_metrics,
    357                        self.get_weight_tensor(features)))
    358     if proba_metrics:
    359       predictions = self.logits_to_predictions(logits, proba=True)
    360       result.update(
    361           _run_metrics(predictions, labels, proba_metrics,
    362                        self.get_weight_tensor(features)))
    363     return result
    364 
    365 
    366 class _BinarySvmTargetColumn(_MultiClassTargetColumn):
    367   """_TargetColumn for binary classification using SVMs."""
    368 
    369   def __init__(self, label_name, weight_column_name):
    370 
    371     def loss_fn(logits, target):
    372       check_shape_op = control_flow_ops.Assert(
    373           math_ops.less_equal(array_ops.rank(target), 2),
    374           ["target's shape should be either [batch_size, 1] or [batch_size]"])
    375       with ops.control_dependencies([check_shape_op]):
    376         target = array_ops.reshape(
    377             target, shape=[array_ops.shape(target)[0], 1])
    378       return loss_ops.hinge_loss(logits, target)
    379 
    380     super(_BinarySvmTargetColumn, self).__init__(
    381         loss_fn=loss_fn,
    382         n_classes=2,
    383         label_name=label_name,
    384         weight_column_name=weight_column_name)
    385 
    386   def logits_to_predictions(self, logits, proba=False):
    387     if proba:
    388       raise ValueError(
    389           "logits to probabilities is not supported for _BinarySvmTargetColumn")
    390 
    391     logits = array_ops.concat([array_ops.zeros_like(logits), logits], 1)
    392     return math_ops.argmax(logits, 1)
    393 
    394 
    395 # TODO(zakaria): use contrib losses.
    396 def _mean_squared_loss(logits, target):
    397   # To prevent broadcasting inside "-".
    398   if len(target.get_shape()) == 1:
    399     target = array_ops.expand_dims(target, dim=[1])
    400 
    401   logits.get_shape().assert_is_compatible_with(target.get_shape())
    402   return math_ops.square(logits - math_ops.to_float(target))
    403 
    404 
    405 def _log_loss_with_two_classes(logits, target):
    406   # sigmoid_cross_entropy_with_logits requires [batch_size, 1] target.
    407   if len(target.get_shape()) == 1:
    408     target = array_ops.expand_dims(target, dim=[1])
    409   loss_vec = nn.sigmoid_cross_entropy_with_logits(
    410       labels=math_ops.to_float(target), logits=logits)
    411   return loss_vec
    412 
    413 
    414 def _softmax_cross_entropy_loss(logits, target):
    415   # Check that we got integer for classification.
    416   if not target.dtype.is_integer:
    417     raise ValueError("Target's dtype should be integer "
    418                      "Instead got %s." % target.dtype)
    419   # sparse_softmax_cross_entropy_with_logits requires [batch_size] target.
    420   if len(target.get_shape()) == 2:
    421     target = array_ops.squeeze(target, squeeze_dims=[1])
    422   loss_vec = nn.sparse_softmax_cross_entropy_with_logits(
    423       labels=target, logits=logits)
    424   return loss_vec
    425 
    426 
    427 def _run_metrics(predictions, labels, metrics, weights):
    428   result = {}
    429   labels = math_ops.cast(labels, predictions.dtype)
    430   for name, metric in six.iteritems(metrics or {}):
    431     if weights is not None:
    432       result[name] = metric(predictions, labels, weights=weights)
    433     else:
    434       result[name] = metric(predictions, labels)
    435 
    436   return result
    437 
    438 
    439 @deprecated(
    440     "2016-11-12", "This file will be removed after the deprecation date."
    441     "Please switch to "
    442     "third_party/tensorflow/contrib/learn/python/learn/estimators/head.py")
    443 def get_default_binary_metrics_for_eval(thresholds):
    444   """Returns a dictionary of basic metrics for logistic regression.
    445 
    446   Args:
    447     thresholds: List of floating point thresholds to use for accuracy,
    448       precision, and recall metrics. If None, defaults to [0.5].
    449 
    450   Returns:
    451     Dictionary mapping metrics string names to metrics functions.
    452   """
    453   metrics = {}
    454   metrics[_MetricKeys.PREDICTION_MEAN] = _predictions_streaming_mean
    455   metrics[_MetricKeys.TARGET_MEAN] = _labels_streaming_mean
    456   # Also include the streaming mean of the label as an accuracy baseline, as
    457   # a reminder to users.
    458   metrics[_MetricKeys.ACCURACY_BASELINE] = _labels_streaming_mean
    459 
    460   metrics[_MetricKeys.AUC] = _streaming_auc
    461 
    462   for threshold in thresholds:
    463     metrics[_MetricKeys.ACCURACY_MEAN %
    464             threshold] = _accuracy_at_threshold(threshold)
    465     # Precision for positive examples.
    466     metrics[_MetricKeys.PRECISION_MEAN % threshold] = _streaming_at_threshold(
    467         metric_ops.streaming_precision_at_thresholds, threshold)
    468     # Recall for positive examples.
    469     metrics[_MetricKeys.RECALL_MEAN % threshold] = _streaming_at_threshold(
    470         metric_ops.streaming_recall_at_thresholds, threshold)
    471 
    472   return metrics
    473 
    474 
    475 def _float_weights_or_none(weights):
    476   if weights is None:
    477     return None
    478   return math_ops.to_float(weights)
    479 
    480 
    481 def _labels_streaming_mean(unused_predictions, labels, weights=None):
    482   return metric_ops.streaming_mean(labels, weights=weights)
    483 
    484 
    485 def _predictions_streaming_mean(predictions, unused_labels, weights=None):
    486   return metric_ops.streaming_mean(predictions, weights=weights)
    487 
    488 
    489 def _streaming_auc(predictions, labels, weights=None):
    490   return metric_ops.streaming_auc(
    491       predictions, labels, weights=_float_weights_or_none(weights))
    492 
    493 
    494 def _accuracy_at_threshold(threshold):
    495 
    496   def _accuracy_metric(predictions, labels, weights=None):
    497     threshold_predictions = math_ops.to_float(
    498         math_ops.greater_equal(predictions, threshold))
    499     return metric_ops.streaming_accuracy(
    500         predictions=threshold_predictions, labels=labels, weights=weights)
    501 
    502   return _accuracy_metric
    503 
    504 
    505 def _streaming_at_threshold(streaming_metrics_fn, threshold):
    506 
    507   def _streaming_metrics(predictions, labels, weights=None):
    508     precision_tensor, update_op = streaming_metrics_fn(
    509         predictions,
    510         labels=labels,
    511         thresholds=[threshold],
    512         weights=_float_weights_or_none(weights))
    513     return array_ops.squeeze(precision_tensor), update_op
    514 
    515   return _streaming_metrics
    516 
    517 
    518 class _MetricKeys(object):
    519   AUC = "auc"
    520   PREDICTION_MEAN = "labels/prediction_mean"
    521   TARGET_MEAN = "labels/actual_target_mean"
    522   ACCURACY_BASELINE = "accuracy/baseline_target_mean"
    523   ACCURACY_MEAN = "accuracy/threshold_%f_mean"
    524   PRECISION_MEAN = "precision/positive_threshold_%f_mean"
    525   RECALL_MEAN = "recall/positive_threshold_%f_mean"
    526