1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for Keras loss functions.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import os 22 import shutil 23 24 import numpy as np 25 26 from tensorflow.python.keras._impl import keras 27 from tensorflow.python.platform import test 28 29 try: 30 import h5py # pylint:disable=g-import-not-at-top 31 except ImportError: 32 h5py = None 33 34 ALL_LOSSES = [keras.losses.mean_squared_error, 35 keras.losses.mean_absolute_error, 36 keras.losses.mean_absolute_percentage_error, 37 keras.losses.mean_squared_logarithmic_error, 38 keras.losses.squared_hinge, 39 keras.losses.hinge, 40 keras.losses.categorical_crossentropy, 41 keras.losses.binary_crossentropy, 42 keras.losses.kullback_leibler_divergence, 43 keras.losses.poisson, 44 keras.losses.cosine_proximity, 45 keras.losses.logcosh, 46 keras.losses.categorical_hinge] 47 48 49 class _MSEMAELoss(object): 50 """Loss function with internal state, for testing serialization code.""" 51 52 def __init__(self, mse_fraction): 53 self.mse_fraction = mse_fraction 54 55 def __call__(self, y_true, y_pred): 56 return (self.mse_fraction * keras.losses.mse(y_true, y_pred) + 57 (1 - self.mse_fraction) * keras.losses.mae(y_true, y_pred)) 58 59 def get_config(self): 60 return {'mse_fraction': self.mse_fraction} 61 62 63 class KerasLossesTest(test.TestCase): 64 65 def test_objective_shapes_3d(self): 66 with self.test_session(): 67 y_a = keras.backend.variable(np.random.random((5, 6, 7))) 68 y_b = keras.backend.variable(np.random.random((5, 6, 7))) 69 for obj in ALL_LOSSES: 70 objective_output = obj(y_a, y_b) 71 self.assertListEqual(objective_output.get_shape().as_list(), [5, 6]) 72 73 def test_objective_shapes_2d(self): 74 with self.test_session(): 75 y_a = keras.backend.variable(np.random.random((6, 7))) 76 y_b = keras.backend.variable(np.random.random((6, 7))) 77 for obj in ALL_LOSSES: 78 objective_output = obj(y_a, y_b) 79 self.assertListEqual(objective_output.get_shape().as_list(), [6,]) 80 81 def test_cce_one_hot(self): 82 with self.test_session(): 83 y_a = keras.backend.variable(np.random.randint(0, 7, (5, 6))) 84 y_b = keras.backend.variable(np.random.random((5, 6, 7))) 85 objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b) 86 assert keras.backend.eval(objective_output).shape == (5, 6) 87 88 y_a = keras.backend.variable(np.random.randint(0, 7, (6,))) 89 y_b = keras.backend.variable(np.random.random((6, 7))) 90 objective_output = keras.losses.sparse_categorical_crossentropy(y_a, y_b) 91 assert keras.backend.eval(objective_output).shape == (6,) 92 93 def test_serialization(self): 94 fn = keras.losses.get('mse') 95 config = keras.losses.serialize(fn) 96 new_fn = keras.losses.deserialize(config) 97 self.assertEqual(fn, new_fn) 98 99 def test_categorical_hinge(self): 100 y_pred = keras.backend.variable(np.array([[0.3, 0.2, 0.1], 101 [0.1, 0.2, 0.7]])) 102 y_true = keras.backend.variable(np.array([[0, 1, 0], [1, 0, 0]])) 103 expected_loss = ((0.3 - 0.2 + 1) + (0.7 - 0.1 + 1)) / 2.0 104 loss = keras.backend.eval(keras.losses.categorical_hinge(y_true, y_pred)) 105 self.assertAllClose(expected_loss, np.mean(loss)) 106 107 def test_serializing_loss_class(self): 108 orig_loss_class = _MSEMAELoss(0.3) 109 with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): 110 serialized = keras.losses.serialize(orig_loss_class) 111 112 with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): 113 deserialized = keras.losses.deserialize(serialized) 114 assert isinstance(deserialized, _MSEMAELoss) 115 assert deserialized.mse_fraction == 0.3 116 117 def test_serializing_model_with_loss_class(self): 118 tmpdir = self.get_temp_dir() 119 self.addCleanup(shutil.rmtree, tmpdir) 120 model_filename = os.path.join(tmpdir, 'custom_loss.h5') 121 122 with self.test_session(): 123 with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): 124 loss = _MSEMAELoss(0.3) 125 inputs = keras.layers.Input((2,)) 126 outputs = keras.layers.Dense(1, name='model_output')(inputs) 127 model = keras.models.Model(inputs, outputs) 128 model.compile(optimizer='sgd', loss={'model_output': loss}) 129 model.fit(np.random.rand(256, 2), np.random.rand(256, 1)) 130 131 if h5py is None: 132 return 133 134 model.save(model_filename) 135 136 with keras.utils.custom_object_scope({'_MSEMAELoss': _MSEMAELoss}): 137 loaded_model = keras.models.load_model(model_filename) 138 loaded_model.predict(np.random.rand(128, 2)) 139 140 141 if __name__ == '__main__': 142 test.main() 143