Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Experimental API for gathering statistics from `tf.data` pipelines."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.ops import iterator_ops
     22 from tensorflow.python.data.util import nest
     23 from tensorflow.python.data.util import sparse
     24 from tensorflow.python.framework import dtypes
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import gen_dataset_ops
     27 
     28 
     29 class StatsAggregator(object):
     30   """A stateful resource that aggregates statistics from one or more iterators.
     31 
     32   To record statistics, use one of the custom transformation functions defined
     33   in this module when defining your @{tf.data.Dataset}. All statistics will be
     34   aggregated by the `StatsAggregator` that is associated with a particular
     35   iterator (see below). For example, to record the total number of bytes
     36   produced by iterating over a dataset:
     37 
     38   ```python
     39   dataset = ...
     40   dataset = dataset.apply(stats_ops.bytes_produced_stats("total_bytes"))
     41   ```
     42 
     43   To associate a `StatsAggregator` with a @{tf.data.Iterator} object, use
     44   the following pattern:
     45 
     46   ```python
     47   dataset = ...
     48   iterator = dataset.make_one_shot_iterator()
     49   stats_aggregator = stats_ops.StatsAggregator()
     50   set_op = stats_op.set_stats_aggregator_op(iterator, stats_aggregator)
     51 
     52   with tf.Session() as sess:
     53     # Running `set_op` will associate `iterator` with `stats_aggregator`.
     54     sess.run(set_op)
     55   ```
     56 
     57   To get a protocol buffer summary of the currently aggregated statistics,
     58   use the `StatsAggregator.get_summary()` tensor. The easiest way to do this
     59   is to add the returned tensor to the @{tf.GraphKeys.SUMMARIES} collection,
     60   so that the summaries will be included with any existing summaries.
     61 
     62   ```python
     63   stats_aggregator = stats_ops.StatsAggregator()
     64   stats_summary = stats_aggregator.get_summary()
     65   tf.add_to_collection(tf.GraphKeys.SUMMARIES, stats_summary)
     66   ```
     67 
     68   Note: This interface is experimental and expected to change. In particular,
     69   we expect to add other implementations of `StatsAggregator` that provide
     70   different ways of exporting statistics, and add more types of statistics.
     71   """
     72 
     73   def __init__(self):
     74     """Creates a `StatsAggregator`."""
     75     self._resource = gen_dataset_ops.stats_aggregator_handle()
     76 
     77   def get_summary(self):
     78     """Returns a string @{tf.Tensor} that summarizes the aggregated statistics.
     79 
     80     The returned tensor will contain a serialized @{tf.summary.Summary} protocol
     81     buffer, which can be used with the standard TensorBoard logging facilities.
     82 
     83     Returns:
     84       A scalar string @{tf.Tensor} that summarizes the aggregated statistics.
     85     """
     86     return gen_dataset_ops.stats_aggregator_summary(self._resource)
     87 
     88   def subscribe(self, iterator):
     89     """Returns a @{tf.Operation} to associate this aggregator with `iterator`.
     90 
     91     Note: Each @{tf.data.Iterator} can be associated with at most one
     92     `StatsAggregator`. After running the operation that this function
     93     returns, all statistics recorded in the iteration of `iterator`
     94     will be stored in `stats_aggregator`.
     95 
     96     Args:
     97       iterator: A @{tf.data.Iterator} object.
     98 
     99     Returns:
    100       A @{tf.Operation} that, when run, associates this aggregator with
    101       `iterator`.
    102     """
    103     if not isinstance(iterator, iterator_ops.Iterator):
    104       raise TypeError("`iterator` must be a `tf.data.Iterator` object.")
    105     return gen_dataset_ops.iterator_set_stats_aggregator(
    106         iterator._iterator_resource, self._resource)  # pylint: disable=protected-access
    107 
    108 
    109 def bytes_produced_stats(tag):
    110   """Records the number of bytes produced by each element of the input dataset.
    111 
    112   To consume the statistics, associate a `StatsAggregator` with an iterator
    113   over the output dataset.
    114 
    115   Args:
    116     tag: String. All statistics recorded by the returned transformation will
    117       be associated with the given `tag`.
    118 
    119   Returns:
    120     A `Dataset` transformation function, which can be passed to
    121     @{tf.data.Dataset.apply}.
    122   """
    123 
    124   def _apply_fn(dataset):
    125     return _StatsDataset(dataset, gen_dataset_ops.bytes_produced_stats_dataset,
    126                          tag)
    127 
    128   return _apply_fn
    129 
    130 
    131 def latency_stats(tag):
    132   """Records the latency of producing each element of the input dataset.
    133 
    134   To consume the statistics, associate a `StatsAggregator` with an iterator
    135   over the output dataset.
    136 
    137   Args:
    138     tag: String. All statistics recorded by the returned transformation will
    139       be associated with the given `tag`.
    140 
    141   Returns:
    142     A `Dataset` transformation function, which can be passed to
    143     @{tf.data.Dataset.apply}.
    144   """
    145 
    146   def _apply_fn(dataset):
    147     return _StatsDataset(dataset, gen_dataset_ops.latency_stats_dataset, tag)
    148 
    149   return _apply_fn
    150 
    151 
    152 class _StatsDataset(dataset_ops.Dataset):
    153   """A `Dataset` that acts as an identity, and also records statistics."""
    154 
    155   def __init__(self, input_dataset, op_function, tag):
    156     super(_StatsDataset, self).__init__()
    157     self._input_dataset = input_dataset
    158     self._op_function = op_function
    159     self._tag = ops.convert_to_tensor(tag, dtype=dtypes.string)
    160 
    161   def _as_variant_tensor(self):
    162     return self._op_function(
    163         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
    164         self._tag,
    165         output_types=nest.flatten(
    166             sparse.as_dense_types(self.output_types, self.output_classes)),
    167         output_shapes=nest.flatten(
    168             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    169 
    170   @property
    171   def output_shapes(self):
    172     return self._input_dataset.output_shapes
    173 
    174   @property
    175   def output_types(self):
    176     return self._input_dataset.output_types
    177 
    178   @property
    179   def output_classes(self):
    180     return self._input_dataset.output_classes
    181