# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for Keras loss functions."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import shutil
     23 
     24 import numpy as np
     25 
     26 from tensorflow.python.keras._impl import keras
     27 from tensorflow.python.platform import test
     28 
     29 try:
     30   import h5py  # pylint:disable=g-import-not-at-top
     31 except ImportError:
     32   h5py = None
     33 
     34 ALL_LOSSES = [keras.losses.mean_squared_error,
     35               keras.losses.mean_absolute_error,
     36               keras.losses.mean_absolute_percentage_error,
     37               keras.losses.mean_squared_logarithmic_error,
     38               keras.losses.squared_hinge,
     39               keras.losses.hinge,
     40               keras.losses.categorical_crossentropy,
     41               keras.losses.binary_crossentropy,
     42               keras.losses.kullback_leibler_divergence,
     43               keras.losses.poisson,
     44               keras.losses.cosine_proximity,
     45               keras.losses.logcosh,
     46               keras.losses.categorical_hinge]
     47 
     48 
     49 class _MSEMAELoss(object):
     50   """Loss function with internal state, for testing serialization code."""
     51 
     52   def __init__(self, mse_fraction):
     53     self.mse_fraction = mse_fraction
     54 
     55   def __call__(self, y_true, y_pred):
     56     return (self.mse_fraction * keras.losses.mse(y_true, y_pred) +
     57             (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred))
     58 
     59   def get_config(self):
     60     return {'mse_fraction': self.mse_fraction}
     61 
     62 
     63 class KerasLossesTest(test.TestCase):
     64 
     65   def test_objective_shapes_3d(self):
     66     with self.test_session():
     67       y_a = keras.backend.variable(np.random.random((5, 6, 7)))
     68       y_b = keras.backend.variable(np.random.random((5, 6, 7)))
     69       for obj in ALL_LOSSES:
     70         objective_output = obj(y_a, y_b)
     71         self.assertListEqual(objective_output.get_shape().as_list(), [5, 6])
     72 
     73   def test_objective_shapes_2d(self):
     74     with self.test_session():
     75       y_a = keras.backend.variable(np.random.random((6, 7)))
     76       y_b = keras.backend.variable(np.random.random((6, 7)))
     77       for obj in ALL_LOSSES:
     78         objective_output = obj(y_a, y_b)
     79         self.assertListEqual(objective_output.get_shape().as_list(), [6,])
     80 
     81   def test_cce_one_hot(self):
     82     with self.test_session():
     83       y_a = keras.backend.variable(np.random.randint(0, 7, (5, 6)))
     84       y_b = keras.backend.variable(np.random.random((5, 6, 7)))
     85       objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b)
     86       assert keras.backend.eval(objective_output).shape == (5, 6)
     87 
     88       y_a = keras.backend.variable(np.random.randint(0, 7, (6,)))
     89       y_b = keras.backend.variable(np.random.random((6, 7)))
     90       objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b)
     91       assert keras.backend.eval(objective_output).shape == (6,)
     92 
     93   def test_serialization(self):
     94     fn = keras.losses.get('mse')
     95     config = keras.losses.serialize(fn)
     96     new_fn = keras.losses.deserialize(config)
     97     self.assertEqual(fn, new_fn)
     98 
     99   def test_categorical_hinge(self):
    100     y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1],
    101                                               [0.1, 0.2, 0.7]]))
    102     y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]]))
    103     expected_loss = ((0.3 - 0.2 + 1) + (0.7 - 0.1 + 1)) / 2.0
    104     loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred))
    105     self.assertAllClose(expected_loss, np.mean(loss))
    106 
    107   def test_serializing_loss_class(self):
    108     orig_loss_class = _MSEMAELoss(0.3)
    109     with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
    110       serialized = keras.losses.serialize(orig_loss_class)
    111 
    112     with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
    113       deserialized = keras.losses.deserialize(serialized)
    114     assert isinstance(deserialized, _MSEMAELoss)
    115     assert deserialized.mse_fraction == 0.3
    116 
    117   def test_serializing_model_with_loss_class(self):
    118     tmpdir = self.get_temp_dir()
    119     self.addCleanup(shutil.rmtree, tmpdir)
    120     model_filename = os.path.join(tmpdir, 'custom_loss.h5')
    121 
    122     with self.test_session():
    123       with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
    124         loss = _MSEMAELoss(0.3)
    125         inputs = keras.layers.Input((2,))
    126         outputs = keras.layers.Dense(1, name='model_output')(inputs)
    127         model = keras.models.Model(inputs, outputs)
    128         model.compile(optimizer='sgd', loss={'model_output': loss})
    129         model.fit(np.random.rand(256, 2), np.random.rand(256, 1))
    130 
    131         if h5py is None:
    132           return
    133 
    134         model.save(model_filename)
    135 
    136       with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}):
    137         loaded_model = keras.models.load_model(model_filename)
    138         loaded_model.predict(np.random.rand(128, 2))
    139 
    140 
if __name__ == '__main__':
  # Run this module's test cases through TensorFlow's platform test runner.
  test.main()
    143