1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for third_party.tensorflow.contrib.kernel_methods.python.losses.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import numpy as np 22 23 from tensorflow.contrib.kernel_methods.python import losses 24 from tensorflow.python.framework import constant_op 25 from tensorflow.python.framework import dtypes 26 from tensorflow.python.framework import errors 27 from tensorflow.python.ops import array_ops 28 from tensorflow.python.platform import test 29 30 31 class SparseMulticlassHingeLossTest(test.TestCase): 32 33 def testInvalidLogitsShape(self): 34 """An error is raised when logits have invalid shape.""" 35 with self.cached_session(): 36 logits = constant_op.constant([-1.0, 2.1], shape=(2,)) 37 labels = constant_op.constant([0, 1]) 38 with self.assertRaises(ValueError): 39 _ = losses.sparse_multiclass_hinge_loss(labels, logits) 40 41 def testInvalidLabelsShape(self): 42 """An error is raised when labels have invalid shape.""" 43 with self.cached_session(): 44 logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) 45 labels = constant_op.constant([1, 0], shape=(1, 1, 2)) 46 with self.assertRaises(ValueError): 47 _ = losses.sparse_multiclass_hinge_loss(labels, logits) 48 49 def testInvalidWeightsShape(self): 50 """An error is raised when weights have invalid shape.""" 51 with self.cached_session(): 52 logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) 53 labels = constant_op.constant([1, 0], shape=(2,)) 54 weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1)) 55 with self.assertRaises(ValueError): 56 _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 57 58 def testInvalidLabelsDtype(self): 59 """An error is raised when labels have invalid shape.""" 60 with self.cached_session(): 61 logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) 62 labels = constant_op.constant([1, 0], dtype=dtypes.float32) 63 with self.assertRaises(ValueError): 64 _ = losses.sparse_multiclass_hinge_loss(labels, logits) 65 66 def testNoneWeightRaisesValueError(self): 67 """An error is raised when weights are None.""" 68 with self.cached_session(): 69 logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) 70 labels = constant_op.constant([1, 0]) 71 with self.assertRaises(ValueError): 72 _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights=None) 73 74 def testInconsistentLabelsAndWeightsShapesSameRank(self): 75 """Error raised when weights and labels have same ranks, different sizes.""" 76 with self.cached_session(): 77 logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1)) 78 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 79 weights = constant_op.constant([1.1, 2.0], shape=(2, 1)) 80 with self.assertRaises(ValueError): 81 _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 82 83 def testInconsistentLabelsAndWeightsShapesDifferentRank(self): 84 """Error raised when weights and labels have different ranks and sizes.""" 85 with self.cached_session(): 86 logits = constant_op.constant([-1.0, 2.1], shape=(2, 1)) 87 labels = constant_op.constant([1, 0], shape=(2, 1)) 88 weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,)) 89 with self.assertRaises(ValueError): 90 _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 91 92 def testOutOfRangeLabels(self): 93 """An error is raised when labels are not in [0, num_classes).""" 94 with self.cached_session(): 95 logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], 96 [0.5, 1.8, -1.0]]) 97 labels = constant_op.constant([1, 0, 4]) 98 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 99 with self.assertRaises(errors.InvalidArgumentError): 100 loss.eval() 101 102 def testZeroLossInt32Labels(self): 103 """Loss is 0 if true class logits sufficiently higher than other classes.""" 104 with self.cached_session(): 105 logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], 106 [0.5, 1.8, -1.0]]) 107 labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32) 108 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 109 self.assertAlmostEqual(loss.eval(), 0.0, 3) 110 111 def testZeroLossInt64Labels(self): 112 """Loss is 0 if true class logits sufficiently higher than other classes.""" 113 with self.cached_session(): 114 logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0], 115 [-0.5, 0.8, -1.0]]) 116 labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64) 117 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 118 self.assertAlmostEqual(loss.eval(), 0.0, 3) 119 120 def testUnknownShape(self): 121 """Result keeps same with `testZeroLossInt32Labels`""" 122 logits_np = np.array([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]]) 123 labels_np = np.array([0, 2, 1], dtype=np.int32) 124 125 logits_shapes = [ 126 [3, 3], # batch_size, num_classes 127 [None, 3], 128 [3, None], 129 [None, None] 130 ] 131 132 for batch_size, num_classes in logits_shapes: 133 with self.cached_session(): 134 logits = array_ops.placeholder( 135 dtypes.float32, shape=(batch_size, num_classes)) 136 labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,)) 137 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 138 result = loss.eval(feed_dict={logits: logits_np, labels: labels_np}) 139 self.assertAlmostEqual(result, 0.0, 3) 140 141 def testCorrectPredictionsSomeClassesInsideMargin(self): 142 """Loss is > 0 even if true class logits are higher than other classes.""" 143 with self.cached_session(): 144 logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0], 145 [1.5, 1.8, -1.0]]) 146 labels = constant_op.constant([0, 2, 1]) 147 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 148 # The first and third samples incur some loss (0.6 and 0.7 respectively). 149 self.assertAlmostEqual(loss.eval(), 0.4333, 3) 150 151 def testIncorrectPredictions(self): 152 """Loss is >0 when an incorrect class has higher logits than true class.""" 153 with self.cached_session(): 154 logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0], 155 [0.5, -1.8, 2.0]]) 156 labels = constant_op.constant([1, 0, 2]) 157 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 158 # The first examples incurs a high loss (3.2) since the logits of an 159 # incorrect class (0) are higher than the logits of the ground truth. The 160 # second example also incures a (smaller) loss (0.4). 161 self.assertAlmostEqual(loss.eval(), 1.2, 3) 162 163 def testIncorrectPredictionsColumnLabels(self): 164 """Same as above but labels is a rank-2 tensor.""" 165 with self.cached_session(): 166 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 167 [0.2, -1.8, 4.0]]) 168 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 169 loss = losses.sparse_multiclass_hinge_loss(labels, logits) 170 # The first examples incurs a high loss (3.0) since the logits of an 171 # incorrect class (0) are higher than the logits of the ground truth. The 172 # second example also incures a (smaller) loss (0.3). 173 self.assertAlmostEqual(loss.eval(), 1.1, 3) 174 175 def testIncorrectPredictionsZeroWeights(self): 176 """Loss is 0 when all weights are missing even if predictions are wrong.""" 177 with self.cached_session(): 178 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 179 [0.2, -1.8, 4.0]]) 180 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 181 weights = constant_op.constant([0.0, 0.0, 0.0], shape=(3, 1)) 182 loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 183 # No overall loss since all weights are 0. 184 self.assertAlmostEqual(loss.eval(), 0.0, 3) 185 186 def testNonZeroLossWithPythonScalarWeights(self): 187 """Weighted loss is correctly computed when weights is a python scalar.""" 188 with self.cached_session(): 189 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 190 [0.2, -1.8, 4.0]]) 191 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 192 weights = 10.0 193 loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 194 self.assertAlmostEqual(loss.eval(), 11.0, 3) 195 196 def testNonZeroLossWithScalarTensorWeights(self): 197 """Weighted loss is correctly computed when weights is a rank-0 tensor.""" 198 with self.cached_session(): 199 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 200 [0.2, -1.8, 4.0]]) 201 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 202 weights = constant_op.constant(5.0) 203 loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 204 self.assertAlmostEqual(loss.eval(), 5.5, 3) 205 206 def testNonZeroLossWith1DTensorWeightsColumnLabels(self): 207 """Weighted loss is correctly computed when weights is a rank-0 tensor.""" 208 with self.cached_session(): 209 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 210 [0.2, -1.8, 4.0]]) 211 labels = constant_op.constant([1, 0, 2], shape=(3, 1)) 212 weights = constant_op.constant([1.0, 0.5, 2.0], shape=(3,)) 213 loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 214 # The overall loss is 1/3 *(3.0*1.0 + 0.5*0.3+ 2.0*0.0) = 1.05 215 self.assertAlmostEqual(loss.eval(), 1.05, 3) 216 217 def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self): 218 """Weighted loss is correctly computed when weights is a rank-0 tensor.""" 219 with self.cached_session(): 220 logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0], 221 [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]]) 222 labels = constant_op.constant([1, 0, 2, 1]) 223 weights = constant_op.constant([[1.0], [0.0], [2.0], [4.0]]) 224 loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights) 225 # The overall loss is 1/3 *(3.0*1.0 + 0.0*0.3+ 2.0*0.0 + 4.0*0.8) = 6.2/3. 226 self.assertAlmostEqual(loss.eval(), 2.06666, 3) 227 228 229 if __name__ == '__main__': 230 test.main() 231