# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Implementation of kernel-methods-related loss operations."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import check_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops.losses import losses


def sparse_multiclass_hinge_loss(
    labels,
    logits,
    weights=1.0,
    scope=None,
    loss_collection=ops.GraphKeys.LOSSES,
    reduction=losses.Reduction.SUM_BY_NONZERO_WEIGHTS):
  """Adds Ops for computing the multiclass hinge loss.

  The implementation is based on the following paper:
  On the Algorithmic Implementation of Multiclass Kernel-based Vector Machines
  by Crammer and Singer.
  link: http://jmlr.csail.mit.edu/papers/volume2/crammer01a/crammer01a.pdf

  This is a generalization of standard (binary) hinge loss. For a given instance
  with correct label c*, the loss is given by:
    loss = max_{c != c*} logits_c - logits_{c*} + 1.
  or equivalently
    loss = max_c { logits_c - logits_{c*} + I_{c != c*} }
  where I_{c != c*} = 1 if c != c* and 0 otherwise.

  Args:
    labels: `Tensor` of shape [batch_size] or [batch_size, 1]. Corresponds to
      the ground truth. Each entry must be an index in `[0, num_classes)`.
    logits: `Tensor` of shape [batch_size, num_classes] corresponding to the
      unscaled logits. Its dtype should be either `float32` or `float64`.
    weights: Optional (python) scalar or `Tensor`. If a non-scalar `Tensor`, its
      rank should be either 1 ([batch_size]) or 2 ([batch_size, 1]).
    scope: The scope for the operations performed in computing the loss.
    loss_collection: collection to which the loss will be added.
    reduction: Type of reduction to apply to loss.

  Returns:
    Weighted loss float `Tensor`. If `reduction` is `NONE`, this has the same
    shape as `labels`; otherwise, it is a scalar.

  Raises:
    ValueError: If `logits`, `labels` or `weights` have invalid or inconsistent
      shapes.
    ValueError: If `labels` tensor has invalid dtype.
  """

  with ops.name_scope(scope, 'sparse_multiclass_hinge_loss', (logits,
                                                              labels)) as scope:

    # Check logits Tensor has valid rank.
    logits_rank = logits.get_shape().ndims
    if logits_rank != 2:
      raise ValueError(
          'logits should have rank 2 ([batch_size, num_classes]). Given rank is'
          ' {}'.format(logits_rank))
    logits_shape = array_ops.shape(logits)
    batch_size, num_classes = logits_shape[0], logits_shape[1]
    logits = math_ops.to_float(logits)

    # Check labels have valid type.
    if labels.dtype != dtypes.int32 and labels.dtype != dtypes.int64:
      raise ValueError(
          'Invalid dtype for labels: {}. Acceptable dtypes: int32 and int64'.
          format(labels.dtype))

    # Check labels and weights have valid ranks and are consistent.
    labels_rank = labels.get_shape().ndims
    if labels_rank not in [1, 2]:
      raise ValueError(
          'labels should have rank 1 ([batch_size]) or 2 ([batch_size, 1]). '
          'Given rank is {}'.format(labels_rank))
    # Runtime check that every label index is < num_classes; flatten labels to
    # rank 1 for the gather below.
    with ops.control_dependencies([
        check_ops.assert_less(labels, math_ops.cast(num_classes, labels.dtype))
    ]):
      labels = array_ops.reshape(labels, shape=[-1])

    weights = ops.convert_to_tensor(weights)
    weights_rank = weights.get_shape().ndims
    if weights_rank not in [0, 1, 2]:
      # Bug fix: report the rank of `weights` (previously this formatted
      # `labels_rank`, producing a misleading error message).
      raise ValueError(
          'non-scalar weights should have rank 1 ([batch_size]) or 2 '
          '([batch_size, 1]). Given rank is {}'.format(weights_rank))

    if weights_rank > 0:
      weights = array_ops.reshape(weights, shape=[-1])
      # Check weights and labels have the same number of elements.
      weights.get_shape().assert_is_compatible_with(labels.get_shape())

    # Compute the logits tensor corresponding to the correct class per instance.
    # Build [batch_size, 2] indices of the form (example_index, label) so that
    # gather_nd picks out logits_{c*} for each example.
    example_indices = array_ops.reshape(
        math_ops.range(batch_size), shape=[batch_size, 1])
    indices = array_ops.concat(
        [
            example_indices,
            array_ops.reshape(
                math_ops.cast(labels, example_indices.dtype),
                shape=[batch_size, 1])
        ],
        axis=1)
    label_logits = array_ops.reshape(
        array_ops.gather_nd(params=logits, indices=indices),
        shape=[batch_size, 1])

    # one_cold_labels is 0.0 at the true class and 1.0 elsewhere, i.e. the
    # indicator I_{c != c*} from the formula in the docstring.
    one_cold_labels = array_ops.one_hot(
        indices=labels, depth=num_classes, on_value=0.0, off_value=1.0)
    margin = logits - label_logits + one_cold_labels
    # relu before reduce_max clamps negative margins; the max over classes then
    # yields max_c { logits_c - logits_{c*} + I_{c != c*} } (>= 0, since the
    # true-class term is exactly 0).
    margin = nn_ops.relu(margin)
    loss = math_ops.reduce_max(margin, axis=1)
    return losses.compute_weighted_loss(
        loss, weights, scope, loss_collection, reduction=reduction)