# Copyright 2015 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
# pylint: disable=g-import-not-at-top
"""Callbacks: utilities called at certain points during model training."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import os

import numpy as np

from tensorflow.python.eager import context
from tensorflow.python.eager import profiler
from tensorflow.python.framework import dtypes
from tensorflow.python.keras import backend as K
from tensorflow.python.keras import callbacks
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import state_ops
from tensorflow.python.ops import summary_ops_v2
from tensorflow.python.ops import variables
from tensorflow.python.platform import tf_logging as logging
from tensorflow.python.summary import summary as tf_summary
from tensorflow.python.training import saver
from tensorflow.python.util.tf_export import keras_export


@keras_export(v1=['keras.callbacks.TensorBoard'])
class TensorBoard(callbacks.Callback):
  # pylint: disable=line-too-long
  """Enable visualizations for TensorBoard.

  TensorBoard is a visualization tool provided with TensorFlow.

  This callback logs events for TensorBoard, including:
  * Metrics summary plots
  * Training graph visualization
  * Activation histograms
  * Sampled profiling

  If you have installed TensorFlow with pip, you should be able
  to launch TensorBoard from the command line:

  ```sh
  tensorboard --logdir=path_to_your_logs
  ```
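
  For example, a minimal sketch of attaching this callback to `Model.fit`
  (assuming `model` is a compiled `tf.keras` model and that the training and
  validation arrays already exist):

  ```python
  tensorboard = tf.keras.callbacks.TensorBoard(
      log_dir='./logs',
      histogram_freq=1,      # requires validation data, see below
      update_freq='epoch')
  model.fit(x_train, y_train,
            epochs=5,
            validation_data=(x_val, y_val),
            callbacks=[tensorboard])
  ```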

  You can find more information about TensorBoard
  [here](https://www.tensorflow.org/get_started/summaries_and_tensorboard).

  Arguments:
      log_dir: the path of the directory where to save the log files to be
        parsed by TensorBoard.
      histogram_freq: frequency (in epochs) at which to compute activation and
        weight histograms for the layers of the model. If set to 0, histograms
        won't be computed. Validation data (or split) must be specified for
        histogram visualizations.
      write_graph: whether to visualize the graph in TensorBoard. The log file
        can become quite large when write_graph is set to True.
      write_grads: whether to visualize gradient histograms in TensorBoard.
        `histogram_freq` must be greater than 0.
      batch_size: size of batch of inputs to feed to the network for histogram
        computation.
      write_images: whether to write model weights to visualize as an image in
        TensorBoard.
      embeddings_freq: frequency (in epochs) at which selected embedding layers
        will be saved. If set to 0, embeddings won't be computed. Data to be
        visualized in TensorBoard's Embedding tab must be passed as
        `embeddings_data`.
      embeddings_layer_names: a list of names of layers to keep an eye on. If
        None or an empty list, all the embedding layers will be watched.
      embeddings_metadata: a dictionary which maps layer names to file names in
        which metadata for the embedding layers is saved. See the
        [details](https://www.tensorflow.org/how_tos/embedding_viz/#metadata_optional)
        about the metadata file format. If the same metadata file is used for
        all embedding layers, a single string can be passed.
      embeddings_data: data to be embedded at layers specified in
        `embeddings_layer_names`. Numpy array (if the model has a single input)
        or list of Numpy arrays (if the model has multiple inputs). Learn
        [more about embeddings](https://www.tensorflow.org/programmers_guide/embedding).
      update_freq: `'batch'`, `'epoch'`, or integer. When using `'batch'`,
        writes the losses and metrics to TensorBoard after each batch. The same
        applies for `'epoch'`. If using an integer, say `1000`, the callback
        will write the metrics and losses to TensorBoard every 1000 samples.
        Note that writing too frequently to TensorBoard can slow down your
        training.
      profile_batch: which batch to profile in order to sample compute
        characteristics. By default, the second batch is profiled. Set
        profile_batch=0 to disable profiling.

  Raises:
      ValueError: If histogram_freq is set and no validation data is provided.

  @compatibility(eager)
  Using the `TensorBoard` callback will work when eager execution is enabled,
  with the restriction that outputting histogram summaries of weights and
  gradients is not supported. Consequently, `histogram_freq` will be ignored.
  @end_compatibility
  """

  # pylint: enable=line-too-long

  def __init__(self,
               log_dir='./logs',
               histogram_freq=0,
               batch_size=32,
               write_graph=True,
               write_grads=False,
               write_images=False,
               embeddings_freq=0,
               embeddings_layer_names=None,
               embeddings_metadata=None,
               embeddings_data=None,
               update_freq='epoch',
               profile_batch=2):
    super(TensorBoard, self).__init__()
    self.log_dir = log_dir
    self.histogram_freq = histogram_freq
    if self.histogram_freq and context.executing_eagerly():
      logging.warning(
          UserWarning('Weight and gradient histograms not supported for eager '
                      'execution, setting `histogram_freq` to `0`.'))
      self.histogram_freq = 0
    self.merged = None
    self.write_graph = write_graph
    self.write_grads = write_grads
    self.write_images = write_images
    self.batch_size = batch_size
    self._current_batch = 0
    self._total_batches_seen = 0
    self._total_val_batches_seen = 0
    self.embeddings_freq = embeddings_freq
    self.embeddings_layer_names = embeddings_layer_names
    self.embeddings_metadata = embeddings_metadata
    self.embeddings_data = embeddings_data
    if update_freq == 'batch':
      self.update_freq = 1
    else:
      self.update_freq = update_freq
    self._samples_seen = 0
    self._samples_seen_at_last_write = 0
    # TODO(fishx): Add a link to the full profiler tutorial.
    self._profile_batch = profile_batch
    # True while a profiler session is running.
    self._is_profiling = False

    # TensorBoard should only write summaries on the chief when in a
    # multi-worker setting.
    self._chief_worker_only = True

  def _init_writer(self, model):
    """Sets file writer."""
    if context.executing_eagerly():
      self.writer = summary_ops_v2.create_file_writer(self.log_dir)
      if not model.run_eagerly and self.write_graph:
        with self.writer.as_default():
          summary_ops_v2.graph(K.get_graph())
    elif self.write_graph:
      self.writer = tf_summary.FileWriter(self.log_dir, K.get_graph())
    else:
      self.writer = tf_summary.FileWriter(self.log_dir)

  def _make_histogram_ops(self, model):
    """Defines histogram ops when histogram_freq > 0."""
    # Only make the histogram summary op if it hasn't already been made.
    if self.histogram_freq and self.merged is None:
      for layer in self.model.layers:
        for weight in layer.weights:
          mapped_weight_name = weight.name.replace(':', '_')
          tf_summary.histogram(mapped_weight_name, weight)
          if self.write_images:
            w_img = array_ops.squeeze(weight)
            shape = K.int_shape(w_img)
            if len(shape) == 2:  # Dense layer kernel case.
              if shape[0] > shape[1]:
                w_img = array_ops.transpose(w_img)
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img, [1, shape[0], shape[1], 1])
            elif len(shape) == 3:  # Convnet case.
              if K.image_data_format() == 'channels_last':
                # Switch to channels_first to display
                # every kernel as a separate image.
                w_img = array_ops.transpose(w_img, perm=[2, 0, 1])
                shape = K.int_shape(w_img)
              w_img = array_ops.reshape(w_img,
                                        [shape[0], shape[1], shape[2], 1])
            elif len(shape) == 1:  # Bias case.
              w_img = array_ops.reshape(w_img, [1, shape[0], 1, 1])
            else:
              # Not possible to handle 3D convnets etc.
              continue

            shape = K.int_shape(w_img)
            assert len(shape) == 4 and shape[-1] in [1, 3, 4]
            tf_summary.image(mapped_weight_name, w_img)

        if self.write_grads:
          for weight in layer.trainable_weights:
            mapped_weight_name = weight.name.replace(':', '_')
            grads = model.optimizer.get_gradients(model.total_loss, weight)

            def is_indexed_slices(grad):
              return type(grad).__name__ == 'IndexedSlices'

            grads = [
                grad.values if is_indexed_slices(grad) else grad
                for grad in grads
            ]
            tf_summary.histogram('{}_grad'.format(mapped_weight_name), grads)

        if hasattr(layer, 'output'):
          if isinstance(layer.output, list):
            for i, output in enumerate(layer.output):
              tf_summary.histogram('{}_out_{}'.format(layer.name, i), output)
          else:
            tf_summary.histogram('{}_out'.format(layer.name), layer.output)

  def set_model(self, model):
    """Sets Keras model and creates summary ops."""

    self.model = model
    self._init_writer(model)
    # Histogram summaries are only enabled in graph mode.
    if not context.executing_eagerly():
      self._make_histogram_ops(model)
      self.merged = tf_summary.merge_all()

    # If both embeddings_freq and embeddings_data are available, we will
    # visualize embeddings.
    if self.embeddings_freq and self.embeddings_data is not None:
      # Avoid circular dependency.
      from tensorflow.python.keras.engine import training_utils  # pylint: disable=g-import-not-at-top
      self.embeddings_data = training_utils.standardize_input_data(
          self.embeddings_data, model.input_names)

      # If embeddings_layer_names are not provided, get all of the embedding
      # layers from the model.
      embeddings_layer_names = self.embeddings_layer_names
      if not embeddings_layer_names:
        embeddings_layer_names = [
            layer.name
            for layer in self.model.layers
            if type(layer).__name__ == 'Embedding'
        ]

      self.assign_embeddings = []
      embeddings_vars = {}

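      # These placeholders are fed in `on_epoch_end`: `batch_id` is the offset
      # of the current batch within `embeddings_data` and `step` is its size,
      # so each batch's activations can be written into a slice of a
      # persistent variable from which the projector checkpoint is saved.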
      self.batch_id = batch_id = array_ops.placeholder(dtypes.int32)
      self.step = step = array_ops.placeholder(dtypes.int32)

      for layer in self.model.layers:
        if layer.name in embeddings_layer_names:
          embedding_input = self.model.get_layer(layer.name).output
          embedding_size = np.prod(embedding_input.shape[1:])
          embedding_input = array_ops.reshape(embedding_input,
                                              (step, int(embedding_size)))
          shape = (self.embeddings_data[0].shape[0], int(embedding_size))
          embedding = variables.Variable(
              array_ops.zeros(shape), name=layer.name + '_embedding')
          embeddings_vars[layer.name] = embedding
          batch = state_ops.assign(embedding[batch_id:batch_id + step],
                                   embedding_input)
          self.assign_embeddings.append(batch)

      self.saver = saver.Saver(list(embeddings_vars.values()))

      # Create the embeddings_metadata dictionary.
      if isinstance(self.embeddings_metadata, str):
        embeddings_metadata = {
            layer_name: self.embeddings_metadata
            for layer_name in embeddings_vars.keys()
        }
      else:
        # embeddings_metadata is already a dictionary (or None).
        embeddings_metadata = self.embeddings_metadata

      try:
        from tensorboard.plugins import projector
      except ImportError:
        raise ImportError('Failed to import TensorBoard. Please make sure that '
                          'TensorBoard integration is complete.')

      # TODO(psv): Add integration tests to test embedding visualization
      # with TensorBoard callback. We are unable to write a unit test for this
      # because the TensorBoard dependency assumes the TensorFlow package is
      # installed.
      config = projector.ProjectorConfig()
      for layer_name, tensor in embeddings_vars.items():
        embedding = config.embeddings.add()
        embedding.tensor_name = tensor.name

        if (embeddings_metadata is not None and
            layer_name in embeddings_metadata):
          embedding.metadata_path = embeddings_metadata[layer_name]

      projector.visualize_embeddings(self.writer, config)

  def _fetch_callback(self, summary):
    """Writes a fetched histogram summary, once per validation batch."""
    self.writer.add_summary(summary, self._total_val_batches_seen)
    self._total_val_batches_seen += 1

  def _write_custom_summaries(self, step, logs=None):
    """Writes metrics out as custom scalar summaries.

    Arguments:
        step: the global step to use for TensorBoard.
        logs: dict. Keys are scalar summary names, values are
            NumPy scalars.
    """
    logs = logs or {}
    if context.executing_eagerly():
      # Use v2 summary ops.
      with self.writer.as_default(), summary_ops_v2.always_record_summaries():
        for name, value in logs.items():
          if isinstance(value, np.ndarray):
            value = value.item()
          summary_ops_v2.scalar(name, value, step=step)
    else:
      # Use the FileWriter from v1 summary.
      for name, value in logs.items():
        if isinstance(value, np.ndarray):
          value = value.item()
        summary = tf_summary.Summary()
        summary_value = summary.value.add()
        summary_value.simple_value = value
        summary_value.tag = name
        self.writer.add_summary(summary, step)
    self.writer.flush()

  def on_batch_end(self, batch, logs=None):
    """Writes scalar summaries for metrics on every training batch.

    Also starts and stops the profiler around the batch selected by
    `profile_batch`.
    """
    # Don't output batch_size and batch number as TensorBoard summaries.
    logs = logs or {}
    self._samples_seen += logs.get('size', 1)
    samples_seen_since = self._samples_seen - self._samples_seen_at_last_write
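    # `update_freq` was normalized to the integer 1 in `__init__` when the
    # caller passed 'batch', so this comparison covers both the 'batch' and
    # integer cases; 'epoch' writes are handled in `on_epoch_end`.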
    if self.update_freq != 'epoch' and samples_seen_since >= self.update_freq:
      batch_logs = {('batch_' + k): v
                    for k, v in logs.items()
                    if k not in ['batch', 'size', 'num_steps']}
      self._write_custom_summaries(self._total_batches_seen, batch_logs)
      self._samples_seen_at_last_write = self._samples_seen
    self._total_batches_seen += 1
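    # `_total_batches_seen` was just incremented, so profiling starts at the
    # end of batch `profile_batch - 1` and is stopped (and saved) at the end
    # of the following batch: exactly the `profile_batch`-th batch is
    # profiled. The `profile_batch == 1` case is handled in `on_train_begin`.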
    if self._is_profiling:
      profiler.save(self.log_dir, profiler.stop())
      self._is_profiling = False
    elif self._total_batches_seen == self._profile_batch - 1:
      profiler.start()
      self._is_profiling = True

  def on_train_begin(self, logs=None):
    if self._profile_batch == 1:
      profiler.start()
      self._is_profiling = True

  def on_epoch_begin(self, epoch, logs=None):
    """Adds the histogram summary op to the model's test-function fetches."""

    # Check whether the histogram summary should run for this epoch.
    if self.histogram_freq and epoch % self.histogram_freq == 0:
      self._epoch = epoch
      # pylint: disable=protected-access
      # Add the histogram summary op if it should run this epoch.
      self.model._make_test_function()
      if self.merged not in self.model.test_function.fetches:
        self.model.test_function.fetches.append(self.merged)
        self.model.test_function.fetch_callbacks[
            self.merged] = self._fetch_callback
      # pylint: enable=protected-access

  def on_epoch_end(self, epoch, logs=None):
    """Checks if summary ops should run next epoch, logs scalar summaries."""

    # Don't output batch_size and batch number as TensorBoard summaries.
    logs = {('epoch_' + k): v
            for k, v in logs.items()
            if k not in ['batch', 'size', 'num_steps']}
    if self.update_freq == 'epoch':
      step = epoch
    else:
      step = self._samples_seen
    self._write_custom_summaries(step, logs)

    # Pop the histogram summary op after each epoch.
    if self.histogram_freq:
      # pylint: disable=protected-access
      if self.merged in self.model.test_function.fetches:
        self.model.test_function.fetches.remove(self.merged)
      if self.merged in self.model.test_function.fetch_callbacks:
        self.model.test_function.fetch_callbacks.pop(self.merged)
      # pylint: enable=protected-access

    if self.embeddings_data is None and self.embeddings_freq:
      raise ValueError('To visualize embeddings, embeddings_data must '
                       'be provided.')

    if self.embeddings_freq and self.embeddings_data is not None:
      if epoch % self.embeddings_freq == 0:
        # We need a second forward pass here because we're passing
        # the `embeddings_data` explicitly. This design allows passing
        # arbitrary data as `embeddings_data` and results from the fact
        # that we need to know the size of the `tf.Variable`s which
        # hold the embeddings in `set_model`. At that point, however,
        # the `validation_data` is not yet set.

        embeddings_data = self.embeddings_data
        n_samples = embeddings_data[0].shape[0]
        i = 0
        sess = K.get_session()
        while i < n_samples:
          step = min(self.batch_size, n_samples - i)
          batch = slice(i, i + step)

          if isinstance(self.model.input, list):
            feed_dict = {
                model_input: embeddings_data[idx][batch]
                for idx, model_input in enumerate(self.model.input)
            }
          else:
            feed_dict = {self.model.input: embeddings_data[0][batch]}

          feed_dict.update({self.batch_id: i, self.step: step})

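          # Run the forward pass in inference mode (learning phase 0) so that
          # layers such as dropout and batch normalization behave as they do
          # at test time.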
          if not isinstance(K.learning_phase(), int):
            feed_dict[K.learning_phase()] = False

          sess.run(self.assign_embeddings, feed_dict=feed_dict)
          self.saver.save(sess,
                          os.path.join(self.log_dir, 'keras_embedding.ckpt'),
                          epoch)

          i += self.batch_size

  def on_train_end(self, logs=None):
    if self._is_profiling:
      profiler.save(self.log_dir, profiler.stop())
      self._is_profiling = False
    self.writer.close()