Home | History | Annotate | Download | only in keras
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for Keras callbacks."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import os
     22 import shutil
     23 import tempfile
     24 
     25 import numpy as np
     26 
     27 from tensorflow.core.framework import summary_pb2
     28 from tensorflow.python import keras
     29 from tensorflow.python.framework import test_util
     30 from tensorflow.python.keras import callbacks_v1
     31 from tensorflow.python.keras import testing_utils
     32 from tensorflow.python.platform import test
     33 from tensorflow.python.training import adam
     34 
     35 
     36 TRAIN_SAMPLES = 10
     37 TEST_SAMPLES = 10
     38 NUM_CLASSES = 2
     39 INPUT_DIM = 3
     40 NUM_HIDDEN = 5
     41 BATCH_SIZE = 5
     42 
     43 
     44 class TestTensorBoardV1(test.TestCase):
     45 
     46   @test_util.run_deprecated_v1
     47   def test_TensorBoard(self):
     48     np.random.seed(1337)
     49 
     50     temp_dir = self.get_temp_dir()
     51     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
     52 
     53     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
     54         train_samples=TRAIN_SAMPLES,
     55         test_samples=TEST_SAMPLES,
     56         input_shape=(INPUT_DIM,),
     57         num_classes=NUM_CLASSES)
     58     y_test = keras.utils.to_categorical(y_test)
     59     y_train = keras.utils.to_categorical(y_train)
     60 
     61     def data_generator(train):
     62       if train:
     63         max_batch_index = len(x_train) // BATCH_SIZE
     64       else:
     65         max_batch_index = len(x_test) // BATCH_SIZE
     66       i = 0
     67       while 1:
     68         if train:
     69           yield (x_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
     70                  y_train[i * BATCH_SIZE:(i + 1) * BATCH_SIZE])
     71         else:
     72           yield (x_test[i * BATCH_SIZE:(i + 1) * BATCH_SIZE],
     73                  y_test[i * BATCH_SIZE:(i + 1) * BATCH_SIZE])
     74         i += 1
     75         i %= max_batch_index
     76 
     77     # case: Sequential
     78     with self.cached_session():
     79       model = keras.models.Sequential()
     80       model.add(
     81           keras.layers.Dense(
     82               NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
     83       # non_trainable_weights: moving_variance, moving_mean
     84       model.add(keras.layers.BatchNormalization())
     85       model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
     86       model.compile(
     87           loss='categorical_crossentropy',
     88           optimizer='sgd',
     89           metrics=['accuracy'])
     90       tsb = callbacks_v1.TensorBoard(
     91           log_dir=temp_dir,
     92           histogram_freq=1,
     93           write_images=True,
     94           write_grads=True,
     95           batch_size=5)
     96       cbks = [tsb]
     97 
     98       # fit with validation data
     99       model.fit(
    100           x_train,
    101           y_train,
    102           batch_size=BATCH_SIZE,
    103           validation_data=(x_test, y_test),
    104           callbacks=cbks,
    105           epochs=3,
    106           verbose=0)
    107 
    108       # fit with validation data and accuracy
    109       model.fit(
    110           x_train,
    111           y_train,
    112           batch_size=BATCH_SIZE,
    113           validation_data=(x_test, y_test),
    114           callbacks=cbks,
    115           epochs=2,
    116           verbose=0)
    117 
    118       # fit generator with validation data
    119       model.fit_generator(
    120           data_generator(True),
    121           len(x_train),
    122           epochs=2,
    123           validation_data=(x_test, y_test),
    124           callbacks=cbks,
    125           verbose=0)
    126 
    127       # fit generator without validation data
    128       # histogram_freq must be zero
    129       tsb.histogram_freq = 0
    130       model.fit_generator(
    131           data_generator(True),
    132           len(x_train),
    133           epochs=2,
    134           callbacks=cbks,
    135           verbose=0)
    136 
    137       # fit generator with validation data and accuracy
    138       tsb.histogram_freq = 1
    139       model.fit_generator(
    140           data_generator(True),
    141           len(x_train),
    142           epochs=2,
    143           validation_data=(x_test, y_test),
    144           callbacks=cbks,
    145           verbose=0)
    146 
    147       # fit generator without validation data and accuracy
    148       tsb.histogram_freq = 0
    149       model.fit_generator(
    150           data_generator(True), len(x_train), epochs=2, callbacks=cbks)
    151       assert os.path.exists(temp_dir)
    152 
    153   @test_util.run_deprecated_v1
    154   def test_TensorBoard_multi_input_output(self):
    155     np.random.seed(1337)
    156     tmpdir = self.get_temp_dir()
    157     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
    158 
    159     with self.cached_session():
    160       filepath = os.path.join(tmpdir, 'logs')
    161 
    162       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
    163           train_samples=TRAIN_SAMPLES,
    164           test_samples=TEST_SAMPLES,
    165           input_shape=(INPUT_DIM,),
    166           num_classes=NUM_CLASSES)
    167       y_test = keras.utils.to_categorical(y_test)
    168       y_train = keras.utils.to_categorical(y_train)
    169 
    170       def data_generator(train):
    171         if train:
    172           max_batch_index = len(x_train) // BATCH_SIZE
    173         else:
    174           max_batch_index = len(x_test) // BATCH_SIZE
    175         i = 0
    176         while 1:
    177           if train:
    178             # simulate multi-input/output models
    179             yield ([x_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2,
    180                    [y_train[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2)
    181           else:
    182             yield ([x_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2,
    183                    [y_test[i * BATCH_SIZE: (i + 1) * BATCH_SIZE]] * 2)
    184           i += 1
    185           i %= max_batch_index
    186 
    187       inp1 = keras.Input((INPUT_DIM,))
    188       inp2 = keras.Input((INPUT_DIM,))
    189       inp = keras.layers.add([inp1, inp2])
    190       hidden = keras.layers.Dense(2, activation='relu')(inp)
    191       hidden = keras.layers.Dropout(0.1)(hidden)
    192       output1 = keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
    193       output2 = keras.layers.Dense(NUM_CLASSES, activation='softmax')(hidden)
    194       model = keras.models.Model([inp1, inp2], [output1, output2])
    195       model.compile(loss='categorical_crossentropy',
    196                     optimizer='sgd',
    197                     metrics=['accuracy'])
    198 
    199       # we must generate new callbacks for each test, as they aren't stateless
    200       def callbacks_factory(histogram_freq):
    201         return [
    202             callbacks_v1.TensorBoard(
    203                 log_dir=filepath,
    204                 histogram_freq=histogram_freq,
    205                 write_images=True,
    206                 write_grads=True,
    207                 batch_size=5)
    208         ]
    209 
    210       # fit without validation data
    211       model.fit([x_train] * 2, [y_train] * 2, batch_size=BATCH_SIZE,
    212                 callbacks=callbacks_factory(histogram_freq=0), epochs=3)
    213 
    214       # fit with validation data and accuracy
    215       model.fit([x_train] * 2, [y_train] * 2, batch_size=BATCH_SIZE,
    216                 validation_data=([x_test] * 2, [y_test] * 2),
    217                 callbacks=callbacks_factory(histogram_freq=1), epochs=2)
    218 
    219       # fit generator without validation data
    220       model.fit_generator(data_generator(True), len(x_train), epochs=2,
    221                           callbacks=callbacks_factory(histogram_freq=0))
    222 
    223       # fit generator with validation data and accuracy
    224       model.fit_generator(data_generator(True), len(x_train), epochs=2,
    225                           validation_data=([x_test] * 2, [y_test] * 2),
    226                           callbacks=callbacks_factory(histogram_freq=1))
    227       assert os.path.isdir(filepath)
    228 
    229   @test_util.run_deprecated_v1
    230   def test_Tensorboard_histogram_summaries_in_test_function(self):
    231 
    232     class FileWriterStub(object):
    233 
    234       def __init__(self, logdir, graph=None):
    235         self.logdir = logdir
    236         self.graph = graph
    237         self.steps_seen = []
    238 
    239       def add_summary(self, summary, global_step):
    240         summary_obj = summary_pb2.Summary()
    241 
    242         # ensure a valid Summary proto is being sent
    243         if isinstance(summary, bytes):
    244           summary_obj.ParseFromString(summary)
    245         else:
    246           assert isinstance(summary, summary_pb2.Summary)
    247           summary_obj = summary
    248 
    249         # keep track of steps seen for the merged_summary op,
    250         # which contains the histogram summaries
    251         if len(summary_obj.value) > 1:
    252           self.steps_seen.append(global_step)
    253 
    254       def flush(self):
    255         pass
    256 
    257       def close(self):
    258         pass
    259 
    260     def _init_writer(obj, _):
    261       obj.writer = FileWriterStub(obj.log_dir)
    262 
    263     np.random.seed(1337)
    264     tmpdir = self.get_temp_dir()
    265     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
    266     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
    267         train_samples=TRAIN_SAMPLES,
    268         test_samples=TEST_SAMPLES,
    269         input_shape=(INPUT_DIM,),
    270         num_classes=NUM_CLASSES)
    271     y_test = keras.utils.to_categorical(y_test)
    272     y_train = keras.utils.to_categorical(y_train)
    273 
    274     with self.cached_session():
    275       model = keras.models.Sequential()
    276       model.add(
    277           keras.layers.Dense(
    278               NUM_HIDDEN, input_dim=INPUT_DIM, activation='relu'))
    279       # non_trainable_weights: moving_variance, moving_mean
    280       model.add(keras.layers.BatchNormalization())
    281       model.add(keras.layers.Dense(NUM_CLASSES, activation='softmax'))
    282       model.compile(
    283           loss='categorical_crossentropy',
    284           optimizer='sgd',
    285           metrics=['accuracy'])
    286       callbacks_v1.TensorBoard._init_writer = _init_writer
    287       tsb = callbacks_v1.TensorBoard(
    288           log_dir=tmpdir,
    289           histogram_freq=1,
    290           write_images=True,
    291           write_grads=True,
    292           batch_size=5)
    293       cbks = [tsb]
    294 
    295       # fit with validation data
    296       model.fit(
    297           x_train,
    298           y_train,
    299           batch_size=BATCH_SIZE,
    300           validation_data=(x_test, y_test),
    301           callbacks=cbks,
    302           epochs=3,
    303           verbose=0)
    304 
    305       self.assertAllEqual(tsb.writer.steps_seen, [0, 1, 2, 3, 4, 5])
    306 
    307   @test_util.run_deprecated_v1
    308   def test_Tensorboard_histogram_summaries_with_generator(self):
    309     np.random.seed(1337)
    310     tmpdir = self.get_temp_dir()
    311     self.addCleanup(shutil.rmtree, tmpdir, ignore_errors=True)
    312 
    313     def generator():
    314       x = np.random.randn(10, 100).astype(np.float32)
    315       y = np.random.randn(10, 10).astype(np.float32)
    316       while True:
    317         yield x, y
    318 
    319     with self.cached_session():
    320       model = testing_utils.get_small_sequential_mlp(
    321           num_hidden=10, num_classes=10, input_dim=100)
    322       model.compile(
    323           loss='categorical_crossentropy',
    324           optimizer='sgd',
    325           metrics=['accuracy'])
    326       tsb = callbacks_v1.TensorBoard(
    327           log_dir=tmpdir,
    328           histogram_freq=1,
    329           write_images=True,
    330           write_grads=True,
    331           batch_size=5)
    332       cbks = [tsb]
    333 
    334       # fit with validation generator
    335       model.fit_generator(
    336           generator(),
    337           steps_per_epoch=2,
    338           epochs=2,
    339           validation_data=generator(),
    340           validation_steps=2,
    341           callbacks=cbks,
    342           verbose=0)
    343 
    344       with self.assertRaises(ValueError):
    345         # fit with validation generator but no
    346         # validation_steps
    347         model.fit_generator(
    348             generator(),
    349             steps_per_epoch=2,
    350             epochs=2,
    351             validation_data=generator(),
    352             callbacks=cbks,
    353             verbose=0)
    354 
    355       self.assertTrue(os.path.exists(tmpdir))
    356 
    357   def test_TensorBoard_with_ReduceLROnPlateau(self):
    358     with self.cached_session():
    359       temp_dir = self.get_temp_dir()
    360       self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    361 
    362       (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
    363           train_samples=TRAIN_SAMPLES,
    364           test_samples=TEST_SAMPLES,
    365           input_shape=(INPUT_DIM,),
    366           num_classes=NUM_CLASSES)
    367       y_test = keras.utils.to_categorical(y_test)
    368       y_train = keras.utils.to_categorical(y_train)
    369 
    370       model = testing_utils.get_small_sequential_mlp(
    371           num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
    372       model.compile(
    373           loss='binary_crossentropy', optimizer='sgd', metrics=['accuracy'])
    374 
    375       cbks = [
    376           keras.callbacks.ReduceLROnPlateau(
    377               monitor='val_loss', factor=0.5, patience=4, verbose=1),
    378           callbacks_v1.TensorBoard(log_dir=temp_dir)
    379       ]
    380 
    381       model.fit(
    382           x_train,
    383           y_train,
    384           batch_size=BATCH_SIZE,
    385           validation_data=(x_test, y_test),
    386           callbacks=cbks,
    387           epochs=2,
    388           verbose=0)
    389 
    390       assert os.path.exists(temp_dir)
    391 
    392   @test_util.run_deprecated_v1
    393   def test_Tensorboard_batch_logging(self):
    394 
    395     class FileWriterStub(object):
    396 
    397       def __init__(self, logdir, graph=None):
    398         self.logdir = logdir
    399         self.graph = graph
    400         self.batches_logged = []
    401         self.summary_values = []
    402         self.summary_tags = []
    403 
    404       def add_summary(self, summary, step):
    405         self.summary_values.append(summary.value[0].simple_value)
    406         self.summary_tags.append(summary.value[0].tag)
    407         self.batches_logged.append(step)
    408 
    409       def flush(self):
    410         pass
    411 
    412       def close(self):
    413         pass
    414 
    415     temp_dir = self.get_temp_dir()
    416     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    417 
    418     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq='batch')
    419     tb_cbk.writer = FileWriterStub(temp_dir)
    420 
    421     for batch in range(5):
    422       tb_cbk.on_batch_end(batch, {'acc': batch})
    423     self.assertEqual(tb_cbk.writer.batches_logged, [0, 1, 2, 3, 4])
    424     self.assertEqual(tb_cbk.writer.summary_values, [0., 1., 2., 3., 4.])
    425     self.assertEqual(tb_cbk.writer.summary_tags, ['batch_acc'] * 5)
    426 
    427   @test_util.run_deprecated_v1
    428   def test_Tensorboard_epoch_and_batch_logging(self):
    429 
    430     class FileWriterStub(object):
    431 
    432       def __init__(self, logdir, graph=None):
    433         self.logdir = logdir
    434         self.graph = graph
    435 
    436       def add_summary(self, summary, step):
    437         if 'batch_' in summary.value[0].tag:
    438           self.batch_summary = (step, summary)
    439         elif 'epoch_' in summary.value[0].tag:
    440           self.epoch_summary = (step, summary)
    441 
    442       def flush(self):
    443         pass
    444 
    445       def close(self):
    446         pass
    447 
    448     temp_dir = self.get_temp_dir()
    449     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    450 
    451     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq='batch')
    452     tb_cbk.writer = FileWriterStub(temp_dir)
    453 
    454     tb_cbk.on_batch_end(0, {'acc': 5.0})
    455     tb_cbk.on_train_end()
    456     batch_step, batch_summary = tb_cbk.writer.batch_summary
    457     self.assertEqual(batch_step, 0)
    458     self.assertEqual(batch_summary.value[0].simple_value, 5.0)
    459 
    460     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq='epoch')
    461     tb_cbk.writer = FileWriterStub(temp_dir)
    462     tb_cbk.on_epoch_end(0, {'acc': 10.0})
    463     tb_cbk.on_train_end()
    464     epoch_step, epoch_summary = tb_cbk.writer.epoch_summary
    465     self.assertEqual(epoch_step, 0)
    466     self.assertEqual(epoch_summary.value[0].simple_value, 10.0)
    467 
    468   @test_util.run_in_graph_and_eager_modes
    469   def test_Tensorboard_eager(self):
    470     temp_dir = tempfile.mkdtemp(dir=self.get_temp_dir())
    471     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    472 
    473     (x_train, y_train), (x_test, y_test) = testing_utils.get_test_data(
    474         train_samples=TRAIN_SAMPLES,
    475         test_samples=TEST_SAMPLES,
    476         input_shape=(INPUT_DIM,),
    477         num_classes=NUM_CLASSES)
    478     y_test = keras.utils.to_categorical(y_test)
    479     y_train = keras.utils.to_categorical(y_train)
    480 
    481     model = testing_utils.get_small_sequential_mlp(
    482         num_hidden=NUM_HIDDEN, num_classes=NUM_CLASSES, input_dim=INPUT_DIM)
    483     model.compile(
    484         loss='binary_crossentropy',
    485         optimizer=adam.AdamOptimizer(0.01),
    486         metrics=['accuracy'])
    487 
    488     cbks = [callbacks_v1.TensorBoard(log_dir=temp_dir)]
    489 
    490     model.fit(
    491         x_train,
    492         y_train,
    493         batch_size=BATCH_SIZE,
    494         validation_data=(x_test, y_test),
    495         callbacks=cbks,
    496         epochs=2,
    497         verbose=0)
    498 
    499     self.assertTrue(os.path.exists(temp_dir))
    500 
    501   @test_util.run_deprecated_v1
    502   def test_TensorBoard_update_freq(self):
    503 
    504     class FileWriterStub(object):
    505 
    506       def __init__(self, logdir, graph=None):
    507         self.logdir = logdir
    508         self.graph = graph
    509         self.batch_summaries = []
    510         self.epoch_summaries = []
    511 
    512       def add_summary(self, summary, step):
    513         if 'batch_' in summary.value[0].tag:
    514           self.batch_summaries.append((step, summary))
    515         elif 'epoch_' in summary.value[0].tag:
    516           self.epoch_summaries.append((step, summary))
    517 
    518       def flush(self):
    519         pass
    520 
    521       def close(self):
    522         pass
    523 
    524     temp_dir = self.get_temp_dir()
    525     self.addCleanup(shutil.rmtree, temp_dir, ignore_errors=True)
    526 
    527     # Epoch mode
    528     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq='epoch')
    529     tb_cbk.writer = FileWriterStub(temp_dir)
    530 
    531     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
    532     self.assertEqual(tb_cbk.writer.batch_summaries, [])
    533     tb_cbk.on_epoch_end(0, {'acc': 10.0, 'size': 1})
    534     self.assertEqual(len(tb_cbk.writer.epoch_summaries), 1)
    535     tb_cbk.on_train_end()
    536 
    537     # Batch mode
    538     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq='batch')
    539     tb_cbk.writer = FileWriterStub(temp_dir)
    540 
    541     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
    542     self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
    543     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 1})
    544     self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
    545     self.assertFalse(tb_cbk.writer.epoch_summaries)
    546     tb_cbk.on_train_end()
    547 
    548     # Integer mode
    549     tb_cbk = callbacks_v1.TensorBoard(temp_dir, update_freq=20)
    550     tb_cbk.writer = FileWriterStub(temp_dir)
    551 
    552     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
    553     self.assertFalse(tb_cbk.writer.batch_summaries)
    554     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
    555     self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
    556     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
    557     self.assertEqual(len(tb_cbk.writer.batch_summaries), 1)
    558     tb_cbk.on_batch_end(0, {'acc': 5.0, 'size': 10})
    559     self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
    560     tb_cbk.on_batch_end(0, {'acc': 10.0, 'size': 10})
    561     self.assertEqual(len(tb_cbk.writer.batch_summaries), 2)
    562     self.assertFalse(tb_cbk.writer.epoch_summaries)
    563     tb_cbk.on_train_end()
    564 
    565 
    566 if __name__ == '__main__':
    567   test.main()
    568