Home | History | Annotate | Download | only in python
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for third_party.tensorflow.contrib.kernel_methods.python.losses."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import numpy as np
     22 
     23 from tensorflow.contrib.kernel_methods.python import losses
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import errors
     27 from tensorflow.python.ops import array_ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 class SparseMulticlassHingeLossTest(test.TestCase):
     32 
     33   def testInvalidLogitsShape(self):
     34     """An error is raised when logits have invalid shape."""
     35     with self.cached_session():
     36       logits = constant_op.constant([-1.0, 2.1], shape=(2,))
     37       labels = constant_op.constant([0, 1])
     38       with self.assertRaises(ValueError):
     39         _ = losses.sparse_multiclass_hinge_loss(labels, logits)
     40 
     41   def testInvalidLabelsShape(self):
     42     """An error is raised when labels have invalid shape."""
     43     with self.cached_session():
     44       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
     45       labels = constant_op.constant([1, 0], shape=(1, 1, 2))
     46       with self.assertRaises(ValueError):
     47         _ = losses.sparse_multiclass_hinge_loss(labels, logits)
     48 
     49   def testInvalidWeightsShape(self):
     50     """An error is raised when weights have invalid shape."""
     51     with self.cached_session():
     52       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
     53       labels = constant_op.constant([1, 0], shape=(2,))
     54       weights = constant_op.constant([1.5, 0.2], shape=(2, 1, 1))
     55       with self.assertRaises(ValueError):
     56         _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
     57 
     58   def testInvalidLabelsDtype(self):
     59     """An error is raised when labels have invalid shape."""
     60     with self.cached_session():
     61       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
     62       labels = constant_op.constant([1, 0], dtype=dtypes.float32)
     63       with self.assertRaises(ValueError):
     64         _ = losses.sparse_multiclass_hinge_loss(labels, logits)
     65 
     66   def testNoneWeightRaisesValueError(self):
     67     """An error is raised when weights are None."""
     68     with self.cached_session():
     69       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
     70       labels = constant_op.constant([1, 0])
     71       with self.assertRaises(ValueError):
     72         _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights=None)
     73 
     74   def testInconsistentLabelsAndWeightsShapesSameRank(self):
     75     """Error raised when weights and labels have same ranks, different sizes."""
     76     with self.cached_session():
     77       logits = constant_op.constant([-1.0, 2.1, 4.1], shape=(3, 1))
     78       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
     79       weights = constant_op.constant([1.1, 2.0], shape=(2, 1))
     80       with self.assertRaises(ValueError):
     81         _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
     82 
     83   def testInconsistentLabelsAndWeightsShapesDifferentRank(self):
     84     """Error raised when weights and labels have different ranks and sizes."""
     85     with self.cached_session():
     86       logits = constant_op.constant([-1.0, 2.1], shape=(2, 1))
     87       labels = constant_op.constant([1, 0], shape=(2, 1))
     88       weights = constant_op.constant([1.1, 2.0, 2.8], shape=(3,))
     89       with self.assertRaises(ValueError):
     90         _ = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
     91 
     92   def testOutOfRangeLabels(self):
     93     """An error is raised when labels are not in [0, num_classes)."""
     94     with self.cached_session():
     95       logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
     96                                      [0.5, 1.8, -1.0]])
     97       labels = constant_op.constant([1, 0, 4])
     98       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
     99       with self.assertRaises(errors.InvalidArgumentError):
    100         loss.eval()
    101 
    102   def testZeroLossInt32Labels(self):
    103     """Loss is 0 if true class logits sufficiently higher than other classes."""
    104     with self.cached_session():
    105       logits = constant_op.constant([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0],
    106                                      [0.5, 1.8, -1.0]])
    107       labels = constant_op.constant([0, 2, 1], dtype=dtypes.int32)
    108       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    109       self.assertAlmostEqual(loss.eval(), 0.0, 3)
    110 
    111   def testZeroLossInt64Labels(self):
    112     """Loss is 0 if true class logits sufficiently higher than other classes."""
    113     with self.cached_session():
    114       logits = constant_op.constant([[2.1, -0.4, -1.0], [1.4, 2.8, 4.0],
    115                                      [-0.5, 0.8, -1.0]])
    116       labels = constant_op.constant([0, 2, 1], dtype=dtypes.int64)
    117       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    118       self.assertAlmostEqual(loss.eval(), 0.0, 3)
    119 
    120   def testUnknownShape(self):
    121     """Result keeps same with `testZeroLossInt32Labels`"""
    122     logits_np = np.array([[1.2, -1.4, -1.0], [1.4, 1.8, 4.0], [0.5, 1.8, -1.0]])
    123     labels_np = np.array([0, 2, 1], dtype=np.int32)
    124 
    125     logits_shapes = [
    126         [3, 3],  # batch_size, num_classes
    127         [None, 3],
    128         [3, None],
    129         [None, None]
    130     ]
    131 
    132     for batch_size, num_classes in logits_shapes:
    133       with self.cached_session():
    134         logits = array_ops.placeholder(
    135             dtypes.float32, shape=(batch_size, num_classes))
    136         labels = array_ops.placeholder(dtypes.int32, shape=(batch_size,))
    137         loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    138         result = loss.eval(feed_dict={logits: logits_np, labels: labels_np})
    139         self.assertAlmostEqual(result, 0.0, 3)
    140 
    141   def testCorrectPredictionsSomeClassesInsideMargin(self):
    142     """Loss is > 0 even if true class logits are higher than other classes."""
    143     with self.cached_session():
    144       logits = constant_op.constant([[1.2, -1.4, 0.8], [1.4, 1.8, 4.0],
    145                                      [1.5, 1.8, -1.0]])
    146       labels = constant_op.constant([0, 2, 1])
    147       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    148       # The first and third samples incur some loss (0.6 and 0.7 respectively).
    149       self.assertAlmostEqual(loss.eval(), 0.4333, 3)
    150 
    151   def testIncorrectPredictions(self):
    152     """Loss is >0 when an incorrect class has higher logits than true class."""
    153     with self.cached_session():
    154       logits = constant_op.constant([[2.6, 0.4, 0.8], [1.4, 0.8, -1.0],
    155                                      [0.5, -1.8, 2.0]])
    156       labels = constant_op.constant([1, 0, 2])
    157       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    158       # The first examples incurs a high loss (3.2) since the logits of an
    159       # incorrect class (0) are higher than the logits of the ground truth. The
    160       # second example also incures a (smaller) loss (0.4).
    161       self.assertAlmostEqual(loss.eval(), 1.2, 3)
    162 
    163   def testIncorrectPredictionsColumnLabels(self):
    164     """Same as above but labels is a rank-2 tensor."""
    165     with self.cached_session():
    166       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    167                                      [0.2, -1.8, 4.0]])
    168       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
    169       loss = losses.sparse_multiclass_hinge_loss(labels, logits)
    170       # The first examples incurs a high loss (3.0) since the logits of an
    171       # incorrect class (0) are higher than the logits of the ground truth. The
    172       # second example also incures a (smaller) loss (0.3).
    173       self.assertAlmostEqual(loss.eval(), 1.1, 3)
    174 
    175   def testIncorrectPredictionsZeroWeights(self):
    176     """Loss is 0 when all weights are missing even if predictions are wrong."""
    177     with self.cached_session():
    178       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    179                                      [0.2, -1.8, 4.0]])
    180       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
    181       weights = constant_op.constant([0.0, 0.0, 0.0], shape=(3, 1))
    182       loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
    183       # No overall loss since all weights are 0.
    184       self.assertAlmostEqual(loss.eval(), 0.0, 3)
    185 
    186   def testNonZeroLossWithPythonScalarWeights(self):
    187     """Weighted loss is correctly computed when weights is a python scalar."""
    188     with self.cached_session():
    189       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    190                                      [0.2, -1.8, 4.0]])
    191       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
    192       weights = 10.0
    193       loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
    194       self.assertAlmostEqual(loss.eval(), 11.0, 3)
    195 
    196   def testNonZeroLossWithScalarTensorWeights(self):
    197     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
    198     with self.cached_session():
    199       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    200                                      [0.2, -1.8, 4.0]])
    201       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
    202       weights = constant_op.constant(5.0)
    203       loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
    204       self.assertAlmostEqual(loss.eval(), 5.5, 3)
    205 
    206   def testNonZeroLossWith1DTensorWeightsColumnLabels(self):
    207     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
    208     with self.cached_session():
    209       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    210                                      [0.2, -1.8, 4.0]])
    211       labels = constant_op.constant([1, 0, 2], shape=(3, 1))
    212       weights = constant_op.constant([1.0, 0.5, 2.0], shape=(3,))
    213       loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
    214       # The overall loss is 1/3 *(3.0*1.0 + 0.5*0.3+ 2.0*0.0) = 1.05
    215       self.assertAlmostEqual(loss.eval(), 1.05, 3)
    216 
    217   def testNonZeroLossWith2DTensorWeights1DLabelsSomeWeightsMissing(self):
    218     """Weighted loss is correctly computed when weights is a rank-0 tensor."""
    219     with self.cached_session():
    220       logits = constant_op.constant([[1.6, -0.4, 0.8], [1.5, 0.8, -1.0],
    221                                      [0.2, -1.8, 4.0], [1.6, 1.8, -4.0]])
    222       labels = constant_op.constant([1, 0, 2, 1])
    223       weights = constant_op.constant([[1.0], [0.0], [2.0], [4.0]])
    224       loss = losses.sparse_multiclass_hinge_loss(labels, logits, weights)
    225       # The overall loss is 1/3 *(3.0*1.0 + 0.0*0.3+ 2.0*0.0 + 4.0*0.8) = 6.2/3.
    226       self.assertAlmostEqual(loss.eval(), 2.06666, 3)
    227 
    228 
    229 if __name__ == '__main__':
    230   test.main()
    231