Home | History | Annotate | Download | only in layers
      1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Utility functions for summary creation."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import functools
     22 import re
     23 
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import standard_ops
     27 from tensorflow.python.summary import summary
     28 
     29 __all__ = [
     30     'summarize_tensor',
     31     'summarize_activation',
     32     'summarize_tensors',
     33     'summarize_collection',
     34     'summarize_variables',
     35     'summarize_weights',
     36     'summarize_biases',
     37     'summarize_activations',
     38 ]
     39 
     40 # TODO(wicke): add more unit tests for summarization functions.
     41 
     42 
     43 def _add_scalar_summary(tensor, tag=None):
     44   """Add a scalar summary operation for the tensor.
     45 
     46   Args:
     47     tensor: The tensor to summarize.
     48     tag: The tag to use, if None then use tensor's op's name.
     49 
     50   Returns:
     51     The created histogram summary.
     52 
     53   Raises:
     54     ValueError: If the tag is already in use or the rank is not 0.
     55   """
     56   tensor.get_shape().assert_has_rank(0)
     57   tag = tag or '%s_summary' % tensor.op.name
     58   return summary.scalar(tag, tensor)
     59 
     60 
     61 def _add_histogram_summary(tensor, tag=None):
     62   """Add a summary operation for the histogram of a tensor.
     63 
     64   Args:
     65     tensor: The tensor to summarize.
     66     tag: The tag to use, if None then use tensor's op's name.
     67 
     68   Returns:
     69     The created histogram summary.
     70 
     71   Raises:
     72     ValueError: If the tag is already in use.
     73   """
     74   tag = tag or '%s_summary' % tensor.op.name
     75   return summary.histogram(tag, tensor)
     76 
     77 
     78 def summarize_activation(op):
     79   """Summarize an activation.
     80 
     81   This applies the given activation and adds useful summaries specific to the
     82   activation.
     83 
     84   Args:
     85     op: The tensor to summarize (assumed to be a layer activation).
     86   Returns:
     87     The summary op created to summarize `op`.
     88   """
     89   if op.op.type in ('Relu', 'Softplus', 'Relu6'):
     90     # Using inputs to avoid floating point equality and/or epsilons.
     91     _add_scalar_summary(
     92         standard_ops.reduce_mean(
     93             standard_ops.to_float(
     94                 standard_ops.less(op.op.inputs[
     95                     0], standard_ops.cast(0.0, op.op.inputs[0].dtype)))),
     96         '%s/zeros' % op.op.name)
     97   if op.op.type == 'Relu6':
     98     _add_scalar_summary(
     99         standard_ops.reduce_mean(
    100             standard_ops.to_float(
    101                 standard_ops.greater(op.op.inputs[
    102                     0], standard_ops.cast(6.0, op.op.inputs[0].dtype)))),
    103         '%s/sixes' % op.op.name)
    104   return _add_histogram_summary(op, '%s/activation' % op.op.name)
    105 
    106 
    107 def summarize_tensor(tensor, tag=None):
    108   """Summarize a tensor using a suitable summary type.
    109 
    110   This function adds a summary op for `tensor`. The type of summary depends on
    111   the shape of `tensor`. For scalars, a `scalar_summary` is created, for all
    112   other tensors, `histogram_summary` is used.
    113 
    114   Args:
    115     tensor: The tensor to summarize
    116     tag: The tag to use, if None then use tensor's op's name.
    117 
    118   Returns:
    119     The summary op created or None for string tensors.
    120   """
    121   # Skips string tensors and boolean tensors (not handled by the summaries).
    122   if (tensor.dtype.is_compatible_with(dtypes.string) or
    123       tensor.dtype.base_dtype == dtypes.bool):
    124     return None
    125 
    126   if tensor.get_shape().ndims == 0:
    127     # For scalars, use a scalar summary.
    128     return _add_scalar_summary(tensor, tag)
    129   else:
    130     # We may land in here if the rank is still unknown. The histogram won't
    131     # hurt if this ends up being a scalar.
    132     return _add_histogram_summary(tensor, tag)
    133 
    134 
    135 def summarize_tensors(tensors, summarizer=summarize_tensor):
    136   """Summarize a set of tensors."""
    137   return [summarizer(tensor) for tensor in tensors]
    138 
    139 
    140 def summarize_collection(collection,
    141                          name_filter=None,
    142                          summarizer=summarize_tensor):
    143   """Summarize a graph collection of tensors, possibly filtered by name."""
    144   tensors = []
    145   for op in ops.get_collection(collection):
    146     if name_filter is None or re.match(name_filter, op.op.name):
    147       tensors.append(op)
    148   return summarize_tensors(tensors, summarizer)
    149 
    150 
    151 # Utility functions for commonly used collections
    152 summarize_variables = functools.partial(summarize_collection,
    153                                         ops.GraphKeys.GLOBAL_VARIABLES)
    154 
    155 summarize_weights = functools.partial(summarize_collection,
    156                                       ops.GraphKeys.WEIGHTS)
    157 
    158 summarize_biases = functools.partial(summarize_collection, ops.GraphKeys.BIASES)
    159 
    160 
    161 def summarize_activations(name_filter=None, summarizer=summarize_activation):
    162   """Summarize activations, using `summarize_activation` to summarize."""
    163   return summarize_collection(ops.GraphKeys.ACTIVATIONS, name_filter,
    164                               summarizer)
    165