# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Experimental API for gathering statistics from `tf.data` pipelines."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.ops import iterator_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


class StatsAggregator(object):
  """A stateful resource that aggregates statistics from one or more iterators.

  To record statistics, use one of the custom transformation functions defined
  in this module when defining your @{tf.data.Dataset}. All statistics will be
  aggregated by the `StatsAggregator` that is associated with a particular
  iterator (see below). For example, to record the total number of bytes
  produced by iterating over a dataset:

  ```python
  dataset = ...
  dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
  ```

  To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use
  the following pattern:

  ```python
  dataset = ...
  iterator = dataset.make_one_shot_iterator()
  stats_aggregator = stats_ops.StatsAggregator()
  subscribe_op = stats_aggregator.subscribe(iterator)

  with tf.Session() as sess:
    # Running `subscribe_op` will associate `iterator` with
    # `stats_aggregator`.
    sess.run(subscribe_op)
  ```

  To get a protocol buffer summary of the currently aggregated statistics,
  use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
  is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection,
  so that the summaries will be included with any existing summaries.

  ```python
  stats_aggregator = stats_ops.StatsAggregator()
  stats_summary = stats_aggregator.get_summary()
  tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
  ```

  Note: This interface is experimental and expected to change. In particular,
  we expect to add other implementations of `StatsAggregator` that provide
  different ways of exporting statistics, and add more types of statistics.
  """

  def __init__(self):
    """Creates a `StatsAggregator`."""
    # Stateful resource handle; all statistics subscribed to this aggregator
    # are accumulated into the resource behind this handle.
    self._resource = gen_dataset_ops.stats_aggregator_handle()

  def get_summary(self):
    """Returns a string @{tf.Tensor} that summarizes the aggregated statistics.

    The returned tensor will contain a serialized @{tf.summary.Summary} protocol
    buffer, which can be used with the standard TensorBoard logging facilities.

    Returns:
      A scalar string @{tf.Tensor} that summarizes the aggregated statistics.
    """
    return gen_dataset_ops.stats_aggregator_summary(self._resource)

  def subscribe(self, iterator):
    """Returns a @{tf.Operation} to associate this aggregator with `iterator`.

    Note: Each @{tf.data.Iterator} can be associated with at most one
    `StatsAggregator`. After running the operation that this function
    returns, all statistics recorded in the iteration of `iterator`
    will be stored in this aggregator.

    Args:
      iterator: A @{tf.data.Iterator} object.

    Returns:
      A @{tf.Operation} that, when run, associates this aggregator with
      `iterator`.

    Raises:
      TypeError: If `iterator` is not a @{tf.data.Iterator} object.
    """
    if not isinstance(iterator, iterator_ops.Iterator):
      raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
    return gen_dataset_ops.iterator_set_stats_aggregator(
        iterator._iterator_resource, self._resource)  # pylint: disable=protected-access


def bytes_produced_stats(tag):
  """Records the number of bytes produced by each element of the input dataset.

  To consume the statistics, associate a `StatsAggregator` with an iterator
  over the output dataset.

  Args:
    tag: String. All statistics recorded by the returned transformation will
      be associated with the given `tag`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    # Wraps `dataset` in an identity transformation that records, under `tag`,
    # the serialized size of each element it passes through.
    return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
                         tag)

  return _apply_fn


def latency_stats(tag):
  """Records the latency of producing each element of the input dataset.

  To consume the statistics, associate a `StatsAggregator` with an iterator
  over the output dataset.

  Args:
    tag: String. All statistics recorded by the returned transformation will
      be associated with the given `tag`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """

  def _apply_fn(dataset):
    # Wraps `dataset` in an identity transformation that records, under `tag`,
    # the time taken to produce each element it passes through.
    return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)

  return _apply_fn


class _StatsDataset(dataset_ops.Dataset):
  """A `Dataset` that acts as an identity, and also records statistics."""

  def __init__(self, input_dataset, op_function, tag):
    """Creates a `_StatsDataset`.

    Args:
      input_dataset: The `Dataset` whose elements are passed through unchanged.
      op_function: The generated dataset op (e.g.
        `gen_dataset_ops.latency_stats_dataset`) that performs the recording.
      tag: String (or string scalar tensor) under which statistics are
        recorded.
    """
    super(_StatsDataset, self).__init__()
    self._input_dataset = input_dataset
    self._op_function = op_function
    self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)

  def _as_variant_tensor(self):
    return self._op_function(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._tag,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  # The transformation is an identity, so all output metadata is forwarded
  # unchanged from the input dataset.

  @property
  def output_shapes(self):
    return self._input_dataset.output_shapes

  @property
  def output_types(self):
    return self._input_dataset.output_types

  @property
  def output_classes(self):
    return self._input_dataset.output_classes