# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Loss operations for use in neural networks.

Note: All the losses are added to the `GraphKeys.LOSSES` collection.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.contrib.framework.python.ops import add_arg_scope
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn
from tensorflow.python.ops import nn_ops
from tensorflow.python.util.deprecation import deprecated
from tensorflow.python.util.deprecation import deprecated_args

__all__ = [
    "absolute_difference", "add_loss", "cosine_distance",
    "compute_weighted_loss", "get_losses", "get_regularization_losses",
    "get_total_loss", "hinge_loss", "log_loss", "mean_pairwise_squared_error",
    "mean_squared_error", "sigmoid_cross_entropy", "softmax_cross_entropy",
    "sparse_softmax_cross_entropy"
]


def _scale_losses(losses, weights):
  """Computes the scaled loss.

  Args:
    losses: A `Tensor` of size [batch_size, d1, ... dN].
    weights: A `Tensor` of size [1], [batch_size] or [batch_size, d1, ... dN].
      The `losses` are reduced (tf.reduce_sum) until their rank matches that
      of `weights`, at which point the reduced `losses` are element-wise
      multiplied by `weights` and a final reduce_sum is computed on the result.
      Conceptually, this operation is equivalent to broadcasting (tiling)
      `weights` to be the same size as `losses`, performing an element-wise
      multiplication, and summing the result.

  Returns:
    A scalar tf.float32 `Tensor` whose value represents the sum of the scaled
      `losses`.
  """
  # First, compute the sum of the losses over all elements:
  start_index = max(0, weights.get_shape().ndims)
  reduction_indices = list(range(start_index, losses.get_shape().ndims))
  reduced_losses = math_ops.reduce_sum(
      losses, reduction_indices=reduction_indices)
  reduced_losses = math_ops.multiply(reduced_losses, weights)
  return math_ops.reduce_sum(reduced_losses)


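# A minimal worked example of the weighting semantics in `_scale_losses`
# (the numbers below are made up for illustration and are not part of this
# module):
#
#   losses  = [[1.0, 2.0], [3.0, 4.0]]   # shape [batch_size=2, d1=2]
#   weights = [0.5, 2.0]                 # shape [batch_size]
#   # Per-sample sums: [3.0, 7.0]; scaled: [1.5, 14.0]; result: 15.5,
#   # exactly as if `weights` had been tiled to [[0.5, 0.5], [2.0, 2.0]]
#   # and multiplied element-wise before the final sum.
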
def _safe_div(numerator, denominator, name="value"):
  """Computes a safe divide which returns 0 if the denominator is zero.

  Note that the function contains an additional conditional check that is
  necessary for avoiding situations where the loss is zero causing NaNs to
  creep into the gradient computation.

  Args:
    numerator: An arbitrary `Tensor`.
    denominator: A `Tensor` whose shape matches `numerator` and whose values are
      assumed to be non-negative.
    name: An optional name for the returned op.

  Returns:
    The element-wise value of the numerator divided by the denominator.
  """
  return array_ops.where(
      math_ops.greater(denominator, 0),
      math_ops.div(numerator,
                   array_ops.where(
                       math_ops.equal(denominator, 0),
                       array_ops.ones_like(denominator), denominator)),
      array_ops.zeros_like(numerator),
      name=name)


def _safe_mean(losses, num_present):
  """Computes a safe mean of the losses.

  Args:
    losses: A tensor whose elements contain individual loss measurements.
    num_present: The number of measurable losses in the tensor.

  Returns:
    A scalar representing the mean of the losses. If `num_present` is zero,
      then zero is returned.
  """
  total_loss = math_ops.reduce_sum(losses)
  return _safe_div(total_loss, num_present)


@deprecated("2016-12-30", "Use tf.losses.compute_weighted_loss instead.")
def compute_weighted_loss(losses, weights=1.0, scope=None):
  """Computes the weighted loss.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the weighted loss.

  Raises:
    ValueError: If `weights` is `None` or the shape is not compatible with
      `losses`, or if the number of dimensions (rank) of either `losses` or
      `weights` is missing.
  """
  with ops.name_scope(scope, "weighted_loss", [losses, weights]):
    losses = ops.convert_to_tensor(losses)
    input_dtype = losses.dtype
    losses = math_ops.to_float(losses)
    weights = math_ops.to_float(ops.convert_to_tensor(weights))

    if losses.get_shape().ndims is None:
      raise ValueError("losses.get_shape().ndims cannot be None")
    weights_shape = weights.get_shape()
    if weights_shape.ndims is None:
      raise ValueError("weights.get_shape().ndims cannot be None")

    if weights_shape.ndims > 1 and weights_shape.dims[-1].is_compatible_with(1):
      weights = array_ops.squeeze(weights, [-1])

    total_loss = _scale_losses(losses, weights)
    num_present = _num_present(losses, weights)
    mean_loss = _safe_mean(total_loss, num_present)
    # Convert the result back to the input type.
    mean_loss = math_ops.cast(mean_loss, input_dtype)
    add_loss(mean_loss)
    return mean_loss


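# Illustrative use of `compute_weighted_loss` (a minimal sketch; the tensor
# values are made-up placeholders, not part of this module):
#
#   losses = tf.constant([[0.5, 1.5], [2.0, 0.0]])  # per-element losses
#   weights = tf.constant([1.0, 0.0])               # drop the second sample
#   loss = compute_weighted_loss(losses, weights)
#   # `loss` is a scalar mean over the weighted, present elements and is
#   # also added to the `GraphKeys.LOSSES` collection.
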
def _num_present(losses, weights, per_batch=False):
  """Computes the number of elements in the loss function induced by `weights`.

  A given weights tensor induces different numbers of usable elements in the
  `losses` tensor. The `weights` tensor is broadcast across `losses` for all
  possible dimensions. For example, if `losses` is a tensor of dimension
  [4, 5, 6, 3] and `weights` is a tensor of size [4, 5], then `weights` is, in
  effect, tiled to match the size of `losses`. Following this effective tile,
  the total number of present elements is the number of non-zero weights.

  Args:
    losses: A tensor of size [batch_size, d1, ... dN].
    weights: A tensor of size [1] or [batch_size, d1, ... dK] where K < N.
    per_batch: Whether to return the number of elements per batch or as a sum
      total.

  Returns:
    The number of present (non-zero) elements in the losses tensor. If
      `per_batch` is True, the value is returned as a tensor of size
      [batch_size]. Otherwise, a single scalar tensor is returned.
  """
  # If weights is a scalar, it's easy to compute:
  if weights.get_shape().ndims == 0:
    batch_size = array_ops.reshape(
        array_ops.slice(array_ops.shape(losses), [0], [1]), [])
    num_per_batch = math_ops.div(
        math_ops.to_float(array_ops.size(losses)),
        math_ops.to_float(batch_size))
    num_per_batch = array_ops.where(
        math_ops.equal(weights, 0), 0.0, num_per_batch)
    num_per_batch = math_ops.multiply(
        array_ops.ones(array_ops.reshape(batch_size, [1])), num_per_batch)
    return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)

  # First, count the number of nonzero weights:
  if weights.get_shape().ndims >= 1:
    reduction_indices = list(range(1, weights.get_shape().ndims))
    num_nonzero_per_batch = math_ops.reduce_sum(
        math_ops.to_float(math_ops.not_equal(weights, 0)),
        reduction_indices=reduction_indices)

  # Next, determine the number of elements that weights would broadcast to:
  broadcast_dims = array_ops.slice(
      array_ops.shape(losses), [weights.get_shape().ndims], [-1])
  num_to_broadcast = math_ops.to_float(math_ops.reduce_prod(broadcast_dims))

  num_per_batch = math_ops.multiply(num_nonzero_per_batch, num_to_broadcast)
  return num_per_batch if per_batch else math_ops.reduce_sum(num_per_batch)


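# A worked instance of the example in the `_num_present` docstring above
# (numbers are illustrative only): with `losses` of shape [4, 5, 6, 3] and
# `weights` of shape [4, 5] containing 12 non-zero entries, each non-zero
# weight covers the 6 * 3 = 18 broadcast elements behind it, so the total
# number of present elements is 12 * 18 = 216 (or a [4]-shaped per-sample
# count when `per_batch=True`).
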
@deprecated("2016-12-30", "Use tf.losses.add_loss instead.")
@add_arg_scope
def add_loss(loss, loss_collection=ops.GraphKeys.LOSSES):
  """Adds an externally defined loss to the collection of losses.

  Args:
    loss: A loss `Tensor`.
    loss_collection: Optional collection to add the loss to.
  """
  if loss_collection:
    ops.add_to_collection(loss_collection, loss)


@deprecated("2016-12-30", "Use tf.losses.get_losses instead.")
def get_losses(scope=None, loss_collection=ops.GraphKeys.LOSSES):
  """Gets the list of losses from the loss_collection.

  Args:
    scope: An optional scope for filtering the losses to return.
    loss_collection: Optional losses collection.

  Returns:
    A list of loss tensors.
  """
  return ops.get_collection(loss_collection, scope)


@deprecated("2016-12-30", "Use tf.losses.get_regularization_losses instead.")
def get_regularization_losses(scope=None):
  """Gets the regularization losses.

  Args:
    scope: An optional scope for filtering the losses to return.

  Returns:
    A list of regularization losses as Tensors.
  """
  return ops.get_collection(ops.GraphKeys.REGULARIZATION_LOSSES, scope)


@deprecated("2016-12-30", "Use tf.losses.get_total_loss instead.")
def get_total_loss(add_regularization_losses=True, name="total_loss"):
  """Returns a tensor whose value represents the total loss.

  Notice that the function adds the given losses to the regularization losses.

  Args:
    add_regularization_losses: A boolean indicating whether or not to use the
      regularization losses in the sum.
    name: The name of the returned tensor.

  Returns:
    A `Tensor` whose value represents the total loss.

  Raises:
    ValueError: if `losses` is not iterable.
  """
  losses = get_losses()
  if add_regularization_losses:
    losses += get_regularization_losses()
  return math_ops.add_n(losses, name=name)


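# Illustrative end-to-end use of the loss collection (a minimal sketch; the
# tensor names below are placeholders, not part of this module):
#
#   loss1 = log_loss(predictions, labels)        # added to GraphKeys.LOSSES
#   loss2 = mean_squared_error(other_preds, other_labels)
#   add_loss(some_custom_loss_tensor)            # register an external loss
#   total = get_total_loss()                     # loss1 + loss2 + custom
#                                                # (+ any regularization losses)
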
@deprecated("2016-12-30", "Use tf.losses.absolute_difference instead.")
def absolute_difference(predictions, labels=None, weights=1.0, scope=None):
  """Adds an Absolute Difference loss to the training procedure.

  `weights` acts as a coefficient for the loss. If a scalar is provided, then
  the loss is simply scaled by the given value. If `weights` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weights` vector. If the shape of
  `weights` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  Args:
    predictions: The predicted outputs.
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    weights: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shape of `predictions` doesn't match that of `labels` or
      if the shape of `weights` is invalid.
  """
  with ops.name_scope(scope, "absolute_difference",
                      [predictions, labels, weights]) as scope:
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    predictions = math_ops.to_float(predictions)
    labels = math_ops.to_float(labels)
    losses = math_ops.abs(math_ops.subtract(predictions, labels))
    return compute_weighted_loss(losses, weights, scope=scope)


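# Illustrative use of `absolute_difference` (a minimal sketch; the tensors are
# made-up placeholders, not part of this module):
#
#   predictions = tf.constant([[0.2, 0.8], [0.6, 0.4]])
#   labels = tf.constant([[0.0, 1.0], [1.0, 0.0]])
#   loss = absolute_difference(predictions, labels, weights=1.0)
#   # `loss` is the mean |predictions - labels| over all elements, here
#   # (0.2 + 0.2 + 0.4 + 0.4) / 4 = 0.3, and is added to GraphKeys.LOSSES.
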
@deprecated("2016-12-30",
            "Use tf.losses.sigmoid_cross_entropy instead. Note that the order "
            "of the predictions and labels arguments has been changed.")
def sigmoid_cross_entropy(logits,
                          multi_class_labels,
                          weights=1.0,
                          label_smoothing=0,
                          scope=None):
  """Creates a cross-entropy loss using tf.nn.sigmoid_cross_entropy_with_logits.

  `weights` acts as a coefficient for the loss. If a scalar is provided,
  then the loss is simply scaled by the given value. If `weights` is a
  tensor of size [`batch_size`], then the loss weights apply to each
  corresponding sample.

  If `label_smoothing` is nonzero, smooth the labels towards 1/2:

      new_multiclass_labels = multiclass_labels * (1 - label_smoothing)
                              + 0.5 * label_smoothing

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    multi_class_labels: [batch_size, num_classes] labels in [0, 1].
    weights: Coefficients for the loss. The tensor must be a scalar, a tensor of
      shape [batch_size] or shape [batch_size, num_classes].
    label_smoothing: If greater than 0 then smooth the labels.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shape of `logits` doesn't match that of
      `multi_class_labels` or if the shape of `weights` is invalid, or if
      `weights` is None.
  """
  with ops.name_scope(scope, "sigmoid_cross_entropy_loss",
                      [logits, multi_class_labels, weights]) as scope:
    logits.get_shape().assert_is_compatible_with(multi_class_labels.get_shape())

    multi_class_labels = math_ops.cast(multi_class_labels, logits.dtype)

    if label_smoothing > 0:
      multi_class_labels = (
          multi_class_labels * (1 - label_smoothing) + 0.5 * label_smoothing)

    losses = nn.sigmoid_cross_entropy_with_logits(
        labels=multi_class_labels, logits=logits, name="xentropy")
    return compute_weighted_loss(losses, weights, scope=scope)


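# Illustrative use of `sigmoid_cross_entropy` with label smoothing (a minimal
# sketch; the tensors are made-up placeholders, not part of this module):
#
#   logits = tf.constant([[2.0, -1.0, 0.5]])        # [batch_size=1, 3 classes]
#   multi_class_labels = tf.constant([[1.0, 0.0, 1.0]])
#   loss = sigmoid_cross_entropy(logits, multi_class_labels,
#                                label_smoothing=0.1)
#   # With label_smoothing=0.1 the effective targets become
#   # 1.0 * 0.9 + 0.05 = 0.95 and 0.0 * 0.9 + 0.05 = 0.05 before the
#   # element-wise sigmoid cross-entropy is averaged into a scalar.
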
@deprecated("2016-12-30",
            "Use tf.losses.softmax_cross_entropy instead. Note that the order "
            "of the logits and labels arguments has been changed.")
def softmax_cross_entropy(logits,
                          onehot_labels,
                          weights=1.0,
                          label_smoothing=0,
                          scope=None):
  """Creates a cross-entropy loss using tf.nn.softmax_cross_entropy_with_logits.

  `weights` acts as a coefficient for the loss. If a scalar is provided,
  then the loss is simply scaled by the given value. If `weights` is a
  tensor of size [`batch_size`], then the loss weights apply to each
  corresponding sample.

  If `label_smoothing` is nonzero, smooth the labels towards 1/num_classes:

      new_onehot_labels = onehot_labels * (1 - label_smoothing)
                          + label_smoothing / num_classes

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    onehot_labels: [batch_size, num_classes] one-hot-encoded labels.
    weights: Coefficients for the loss. The tensor must be a scalar or a tensor
      of shape [batch_size].
    label_smoothing: If greater than 0 then smooth the labels.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the mean loss value.

  Raises:
    ValueError: If the shape of `logits` doesn't match that of `onehot_labels`
      or if the shape of `weights` is invalid or if `weights` is None.
  """
  with ops.name_scope(scope, "softmax_cross_entropy_loss",
                      [logits, onehot_labels, weights]) as scope:
    logits.get_shape().assert_is_compatible_with(onehot_labels.get_shape())

    onehot_labels = math_ops.cast(onehot_labels, logits.dtype)

    if label_smoothing > 0:
      num_classes = math_ops.cast(
          array_ops.shape(onehot_labels)[1], logits.dtype)
      smooth_positives = 1.0 - label_smoothing
      smooth_negatives = label_smoothing / num_classes
      onehot_labels = onehot_labels * smooth_positives + smooth_negatives

    losses = nn.softmax_cross_entropy_with_logits(
        labels=onehot_labels, logits=logits, name="xentropy")
    return compute_weighted_loss(losses, weights, scope=scope)


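# Illustrative use of `softmax_cross_entropy` (a minimal sketch; the tensors
# are made-up placeholders, not part of this module):
#
#   logits = tf.constant([[10.0, 0.0, 0.0], [0.0, 10.0, 0.0]])
#   onehot_labels = tf.constant([[1.0, 0.0, 0.0], [0.0, 1.0, 0.0]])
#   loss = softmax_cross_entropy(logits, onehot_labels,
#                                weights=1.0, label_smoothing=0.2)
#   # label_smoothing=0.2 with 3 classes turns each 1.0 target into
#   # 1.0 * 0.8 + 0.2 / 3 ≈ 0.867 and each 0.0 target into 0.2 / 3 ≈ 0.067.
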
@deprecated("2016-12-30",
            "Use tf.losses.sparse_softmax_cross_entropy instead. Note that "
            "the order of the logits and labels arguments has been changed.")
def sparse_softmax_cross_entropy(logits, labels, weights=1.0, scope=None):
  """Cross-entropy loss using `tf.nn.sparse_softmax_cross_entropy_with_logits`.

  `weights` acts as a coefficient for the loss. If a scalar is provided,
  then the loss is simply scaled by the given value. If `weights` is a
  tensor of size [`batch_size`], then the loss weights apply to each
  corresponding sample.

  Args:
    logits: [batch_size, num_classes] logits outputs of the network.
    labels: [batch_size, 1] or [batch_size] labels of dtype `int32` or `int64`
      in the range `[0, num_classes)`.
    weights: Coefficients for the loss. The tensor must be a scalar or a tensor
      of shape [batch_size] or [batch_size, 1].
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the mean loss value.

  Raises:
    ValueError: If the shapes of `logits`, `labels`, and `weights` are
      incompatible, or if `weights` is None.
  """
  with ops.name_scope(scope, "sparse_softmax_cross_entropy_loss",
                      [logits, labels, weights]) as scope:
    labels = array_ops.reshape(labels, shape=[array_ops.shape(labels)[0]])

    losses = nn.sparse_softmax_cross_entropy_with_logits(
        labels=labels, logits=logits, name="xentropy")
    return compute_weighted_loss(losses, weights, scope=scope)


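# Illustrative use of `sparse_softmax_cross_entropy` (a minimal sketch; the
# tensors are made-up placeholders, not part of this module):
#
#   logits = tf.constant([[2.0, 0.5, -1.0], [0.1, 3.0, 0.2]])
#   labels = tf.constant([0, 1])           # integer class ids, not one-hot
#   loss = sparse_softmax_cross_entropy(logits, labels)
#   # Equivalent to softmax_cross_entropy with one-hot targets
#   # [[1, 0, 0], [0, 1, 0]], but takes the class indices directly.
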
@deprecated("2016-12-30",
            "Use tf.losses.log_loss instead. Note that the order of the "
            "predictions and labels arguments has been changed.")
def log_loss(predictions, labels=None, weights=1.0, epsilon=1e-7, scope=None):
  """Adds a Log Loss term to the training procedure.

  `weights` acts as a coefficient for the loss. If a scalar is provided, then
  the loss is simply scaled by the given value. If `weights` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weights` vector. If the shape of
  `weights` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  Args:
    predictions: The predicted outputs.
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    weights: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    epsilon: A small increment to add to avoid taking a log of zero.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shape of `predictions` doesn't match that of `labels` or
      if the shape of `weights` is invalid.
  """
  with ops.name_scope(scope, "log_loss",
                      [predictions, labels, weights]) as scope:
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    predictions = math_ops.to_float(predictions)
    labels = math_ops.to_float(labels)
    losses = -math_ops.multiply(
        labels, math_ops.log(predictions + epsilon)) - math_ops.multiply(
            (1 - labels), math_ops.log(1 - predictions + epsilon))
    return compute_weighted_loss(losses, weights, scope=scope)


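# Illustrative use of `log_loss` (a minimal sketch; the tensors are made-up
# placeholders, not part of this module):
#
#   predictions = tf.constant([[0.9, 0.1], [0.2, 0.8]])  # probabilities
#   labels = tf.constant([[1.0, 0.0], [0.0, 1.0]])
#   loss = log_loss(predictions, labels)
#   # Each element contributes -label * log(p) - (1 - label) * log(1 - p),
#   # and the weighted mean over present elements is returned as a scalar.
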
@deprecated("2016-12-30",
            "Use tf.losses.hinge_loss instead. Note that the order of the "
            "logits and labels arguments has been changed, and to stay "
            "unweighted, reduction=Reduction.NONE")
def hinge_loss(logits, labels=None, scope=None):
  """Method that returns the loss tensor for hinge loss.

  Args:
    logits: The logits, a float tensor.
    labels: The ground truth output tensor. Its shape should match the shape of
      logits. The values of the tensor are expected to be 0.0 or 1.0.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    An unweighted `Tensor` of the same shape as `logits` and `labels`
      representing the loss values across the batch.

  Raises:
    ValueError: If the shapes of `logits` and `labels` don't match.
  """
  with ops.name_scope(scope, "hinge_loss", [logits, labels]) as scope:
    logits.get_shape().assert_is_compatible_with(labels.get_shape())
    # We first need to convert binary labels to -1/1 labels (as floats).
    labels = math_ops.to_float(labels)
    all_ones = array_ops.ones_like(labels)
    labels = math_ops.subtract(2 * labels, all_ones)
    return nn_ops.relu(
        math_ops.subtract(all_ones, math_ops.multiply(labels, logits)))


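# A worked example for `hinge_loss` (illustrative numbers only): a 0/1 label of
# 1.0 is mapped to +1 and 0.0 to -1, and the element-wise loss is
# max(0, 1 - label * logit). For logits [2.0, -0.5] with labels [1.0, 0.0]:
#
#   losses = hinge_loss(tf.constant([2.0, -0.5]), tf.constant([1.0, 0.0]))
#   # -> [max(0, 1 - 1 * 2.0), max(0, 1 - (-1) * (-0.5))] = [0.0, 0.5]
#   # Note: unlike the other losses here, this returns per-element values and
#   # does not call compute_weighted_loss or add to GraphKeys.LOSSES.
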
@deprecated("2016-12-30", "Use tf.losses.mean_squared_error instead.")
def mean_squared_error(predictions, labels=None, weights=1.0, scope=None):
  """Adds a Sum-of-Squares loss to the training procedure.

  `weights` acts as a coefficient for the loss. If a scalar is provided, then
  the loss is simply scaled by the given value. If `weights` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weights` vector. If the shape of
  `weights` matches the shape of `predictions`, then the loss of each
  measurable element of `predictions` is scaled by the corresponding value of
  `weights`.

  Args:
    predictions: The predicted outputs.
    labels: The ground truth output tensor, same dimensions as 'predictions'.
    weights: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shape of `predictions` doesn't match that of `labels` or
      if the shape of `weights` is invalid.
  """
  with ops.name_scope(scope, "mean_squared_error",
                      [predictions, labels, weights]) as scope:
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    predictions = math_ops.to_float(predictions)
    labels = math_ops.to_float(labels)
    losses = math_ops.square(math_ops.subtract(predictions, labels))
    return compute_weighted_loss(losses, weights, scope=scope)


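# Illustrative use of `mean_squared_error` with per-sample weights (a minimal
# sketch; the tensors are made-up placeholders, not part of this module):
#
#   predictions = tf.constant([[1.0, 2.0], [3.0, 4.0]])
#   labels = tf.constant([[1.5, 2.0], [3.0, 2.0]])
#   weights = tf.constant([1.0, 0.5])   # rescale the second sample
#   loss = mean_squared_error(predictions, labels, weights)
#   # Squared errors [[0.25, 0.0], [0.0, 4.0]] are summed per sample,
#   # scaled by `weights`, and averaged over the present elements.
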
@deprecated("2016-12-30",
            "Use tf.losses.mean_pairwise_squared_error instead. Note that the "
            "order of the predictions and labels arguments has been changed.")
def mean_pairwise_squared_error(predictions,
                                labels=None,
                                weights=1.0,
                                scope=None):
  """Adds a pairwise-errors-squared loss to the training procedure.

  Unlike `mean_squared_error`, which is a measure of the differences between
  corresponding elements of `predictions` and `labels`,
  `mean_pairwise_squared_error` is a measure of the differences between pairs of
  corresponding elements of `predictions` and `labels`.

  For example, if `labels`=[a, b, c] and `predictions`=[x, y, z], three pairs
  of differences are summed to compute the loss:
    loss = [ ((a-b) - (x-y)).^2 + ((a-c) - (x-z)).^2 + ((b-c) - (y-z)).^2 ] / 3

  Note that since the inputs are of size [batch_size, d0, ... dN], the
  corresponding pairs are computed within each batch sample but not across
  samples within a batch. For example, if `predictions` represents a batch of
  16 grayscale images of dimension [batch_size, 100, 200], then the set of pairs
  is drawn from each image, but not across images.

  `weights` acts as a coefficient for the loss. If a scalar is provided, then
  the loss is simply scaled by the given value. If `weights` is a tensor of size
  [batch_size], then the total loss for each sample of the batch is rescaled
  by the corresponding element in the `weights` vector.

  Args:
    predictions: The predicted outputs, a tensor of size [batch_size, d0, .. dN]
      where N+1 is the total number of dimensions in `predictions`.
    labels: The ground truth output tensor, whose shape must match the shape of
      the `predictions` tensor.
    weights: Coefficients for the loss: a scalar, a tensor of shape [batch_size]
      or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If the shape of `predictions` doesn't match that of `labels` or
      if the shape of `weights` is invalid.
  """
  with ops.name_scope(scope, "mean_pairwise_squared_error",
                      [predictions, labels, weights]) as scope:
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())
    predictions = math_ops.to_float(predictions)
    labels = math_ops.to_float(labels)
    weights = math_ops.to_float(ops.convert_to_tensor(weights))

    diffs = math_ops.subtract(predictions, labels)

    # Need to verify here since the function doesn't use compute_weighted_loss.
    if diffs.get_shape().ndims is None:
      raise ValueError("diffs.get_shape().ndims cannot be None")
    if weights.get_shape().ndims is None:
      raise ValueError("weights.get_shape().ndims cannot be None")

    reduction_indices = list(range(1, diffs.get_shape().ndims))

    sum_squares_diff_per_batch = math_ops.reduce_sum(
        math_ops.square(diffs), reduction_indices=reduction_indices)
    num_present_per_batch = _num_present(diffs, weights, per_batch=True)

    term1 = 2.0 * _safe_div(sum_squares_diff_per_batch, num_present_per_batch)

    sum_diff = math_ops.reduce_sum(diffs, reduction_indices=reduction_indices)
    term2 = 2.0 * _safe_div(
        math_ops.square(sum_diff), math_ops.square(num_present_per_batch))

    loss = _scale_losses(term1 - term2, weights)

    mean_loss = array_ops.where(
        math_ops.reduce_sum(num_present_per_batch) > 0,
        loss,
        array_ops.zeros_like(loss),
        name="value")
    add_loss(mean_loss)
    return mean_loss


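# A worked instance of the pairwise formula quoted in the docstring of
# `mean_pairwise_squared_error` above (illustrative numbers only): with
# labels = [a, b, c] = [1, 2, 4] and predictions = [x, y, z] = [1, 3, 4]
# for a single sample,
#
#   loss = [((a-b) - (x-y))^2 + ((a-c) - (x-z))^2 + ((b-c) - (y-z))^2] / 3
#        = [(-1 - (-2))^2 + (-3 - (-3))^2 + (-2 - (-1))^2] / 3
#        = (1 + 0 + 1) / 3
#
# i.e. only the relative structure within the sample matters; adding a constant
# offset to every prediction leaves the loss unchanged.
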
@deprecated("2016-12-30", "Use tf.losses.cosine_distance instead.")
@deprecated_args(None, "dim is deprecated, use axis instead", "dim")
def cosine_distance(predictions,
                    labels=None,
                    axis=None,
                    weights=1.0,
                    scope=None,
                    dim=None):
  """Adds a cosine-distance loss to the training procedure.

  Note that the function assumes that `predictions` and `labels` are already
  unit-normalized.

  Args:
    predictions: An arbitrary matrix.
    labels: A `Tensor` whose shape matches 'predictions'.
    axis: The dimension along which the cosine distance is computed.
    weights: Coefficients for the loss: a scalar, a tensor of shape
      [batch_size] or a tensor whose shape matches `predictions`.
    scope: The scope for the operations performed in computing the loss.
    dim: The old (deprecated) name for `axis`.

  Returns:
    A scalar `Tensor` representing the loss value.

  Raises:
    ValueError: If `predictions` shape doesn't match `labels` shape, or
      `weights` is `None`.
  """
  if dim is not None:
    if axis is not None:
      raise ValueError("Cannot specify both 'axis' and 'dim'")
    axis = dim
  if axis is None and dim is None:
    raise ValueError("You must specify 'axis'.")
  with ops.name_scope(scope, "cosine_distance_loss",
                      [predictions, labels, weights]) as scope:
    predictions.get_shape().assert_is_compatible_with(labels.get_shape())

    predictions = math_ops.to_float(predictions)
    labels = math_ops.to_float(labels)

    radial_diffs = math_ops.multiply(predictions, labels)
    losses = 1 - math_ops.reduce_sum(
        radial_diffs, reduction_indices=[
            axis,
        ])
    return compute_weighted_loss(losses, weights, scope=scope)
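

# Illustrative use of `cosine_distance` (a minimal sketch; the tensors are
# made-up, unit-normalized placeholders, not part of this module):
#
#   predictions = tf.constant([[1.0, 0.0], [0.6, 0.8]])   # unit vectors
#   labels = tf.constant([[0.0, 1.0], [0.6, 0.8]])
#   loss = cosine_distance(predictions, labels, axis=1)
#   # Per-sample distances are 1 - <prediction, label>: 1.0 for the orthogonal
#   # pair and 0.0 for the identical pair; their weighted mean is returned.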