1 # Copyright 2015 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Utility functions for summary creation.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import functools 22 import re 23 24 from tensorflow.python.framework import dtypes 25 from tensorflow.python.framework import ops 26 from tensorflow.python.ops import standard_ops 27 from tensorflow.python.summary import summary 28 29 __all__ = [ 30 'summarize_tensor', 31 'summarize_activation', 32 'summarize_tensors', 33 'summarize_collection', 34 'summarize_variables', 35 'summarize_weights', 36 'summarize_biases', 37 'summarize_activations', 38 ] 39 40 # TODO(wicke): add more unit tests for summarization functions. 41 42 43 def _add_scalar_summary(tensor, tag=None): 44 """Add a scalar summary operation for the tensor. 45 46 Args: 47 tensor: The tensor to summarize. 48 tag: The tag to use, if None then use tensor's op's name. 49 50 Returns: 51 The created histogram summary. 52 53 Raises: 54 ValueError: If the tag is already in use or the rank is not 0. 55 """ 56 tensor.get_shape().assert_has_rank(0) 57 tag = tag or '%s_summary' % tensor.op.name 58 return summary.scalar(tag, tensor) 59 60 61 def _add_histogram_summary(tensor, tag=None): 62 """Add a summary operation for the histogram of a tensor. 63 64 Args: 65 tensor: The tensor to summarize. 66 tag: The tag to use, if None then use tensor's op's name. 67 68 Returns: 69 The created histogram summary. 70 71 Raises: 72 ValueError: If the tag is already in use. 73 """ 74 tag = tag or '%s_summary' % tensor.op.name 75 return summary.histogram(tag, tensor) 76 77 78 def summarize_activation(op): 79 """Summarize an activation. 80 81 This applies the given activation and adds useful summaries specific to the 82 activation. 83 84 Args: 85 op: The tensor to summarize (assumed to be a layer activation). 86 Returns: 87 The summary op created to summarize `op`. 88 """ 89 if op.op.type in ('Relu', 'Softplus', 'Relu6'): 90 # Using inputs to avoid floating point equality and/or epsilons. 91 _add_scalar_summary( 92 standard_ops.reduce_mean( 93 standard_ops.to_float( 94 standard_ops.less(op.op.inputs[ 95 0], standard_ops.cast(0.0, op.op.inputs[0].dtype)))), 96 '%s/zeros' % op.op.name) 97 if op.op.type == 'Relu6': 98 _add_scalar_summary( 99 standard_ops.reduce_mean( 100 standard_ops.to_float( 101 standard_ops.greater(op.op.inputs[ 102 0], standard_ops.cast(6.0, op.op.inputs[0].dtype)))), 103 '%s/sixes' % op.op.name) 104 return _add_histogram_summary(op, '%s/activation' % op.op.name) 105 106 107 def summarize_tensor(tensor, tag=None): 108 """Summarize a tensor using a suitable summary type. 109 110 This function adds a summary op for `tensor`. The type of summary depends on 111 the shape of `tensor`. For scalars, a `scalar_summary` is created, for all 112 other tensors, `histogram_summary` is used. 113 114 Args: 115 tensor: The tensor to summarize 116 tag: The tag to use, if None then use tensor's op's name. 117 118 Returns: 119 The summary op created or None for string tensors. 120 """ 121 # Skips string tensors and boolean tensors (not handled by the summaries). 122 if (tensor.dtype.is_compatible_with(dtypes.string) or 123 tensor.dtype.base_dtype == dtypes.bool): 124 return None 125 126 if tensor.get_shape().ndims == 0: 127 # For scalars, use a scalar summary. 128 return _add_scalar_summary(tensor, tag) 129 else: 130 # We may land in here if the rank is still unknown. The histogram won't 131 # hurt if this ends up being a scalar. 132 return _add_histogram_summary(tensor, tag) 133 134 135 def summarize_tensors(tensors, summarizer=summarize_tensor): 136 """Summarize a set of tensors.""" 137 return [summarizer(tensor) for tensor in tensors] 138 139 140 def summarize_collection(collection, 141 name_filter=None, 142 summarizer=summarize_tensor): 143 """Summarize a graph collection of tensors, possibly filtered by name.""" 144 tensors = [] 145 for op in ops.get_collection(collection): 146 if name_filter is None or re.match(name_filter, op.op.name): 147 tensors.append(op) 148 return summarize_tensors(tensors, summarizer) 149 150 151 # Utility functions for commonly used collections 152 summarize_variables = functools.partial(summarize_collection, 153 ops.GraphKeys.GLOBAL_VARIABLES) 154 155 summarize_weights = functools.partial(summarize_collection, 156 ops.GraphKeys.WEIGHTS) 157 158 summarize_biases = functools.partial(summarize_collection, ops.GraphKeys.BIASES) 159 160 161 def summarize_activations(name_filter=None, summarizer=summarize_activation): 162 """Summarize activations, using `summarize_activation` to summarize.""" 163 return summarize_collection(ops.GraphKeys.ACTIVATIONS, name_filter, 164 summarizer) 165