Home | History | Annotate | Download | only in ops
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Grouping dataset transformations."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.data.util import nest
     22 from tensorflow.python.data.util import sparse
     23 from tensorflow.python.framework import dtypes
     24 from tensorflow.python.framework import function
     25 from tensorflow.python.framework import ops
     26 from tensorflow.python.ops import gen_dataset_ops
     27 
     28 
     29 def group_by_window(key_func,
     30                     reduce_func,
     31                     window_size=None,
     32                     window_size_func=None):
     33   """A transformation that groups windows of elements by key and reduces them.
     34 
     35   This transformation maps each consecutive element in a dataset to a key
     36   using `key_func` and groups the elements by key. It then applies
     37   `reduce_func` to at most `window_size_func(key)` elements matching the same
     38   key. All execpt the final window for each key will contain
     39   `window_size_func(key)` elements; the final window may be smaller.
     40 
     41   You may provide either a constant `window_size` or a window size determined by
     42   the key through `window_size_func`.
     43 
     44   Args:
     45     key_func: A function mapping a nested structure of tensors
     46       (having shapes and types defined by `self.output_shapes` and
     47       `self.output_types`) to a scalar `tf.int64` tensor.
     48     reduce_func: A function mapping a key and a dataset of up to `window_size`
     49       consecutive elements matching that key to another dataset.
     50     window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
     51       consecutive elements matching the same key to combine in a single
     52       batch, which will be passed to `reduce_func`. Mutually exclusive with
     53       `window_size_func`.
     54     window_size_func: A function mapping a key to a `tf.int64` scalar
     55       `tf.Tensor`, representing the number of consecutive elements matching
     56       the same key to combine in a single batch, which will be passed to
     57       `reduce_func`. Mutually exclusive with `window_size`.
     58 
     59   Returns:
     60     A `Dataset` transformation function, which can be passed to
     61     @{tf.data.Dataset.apply}.
     62 
     63   Raises:
     64     ValueError: if neither or both of {`window_size`, `window_size_func`} are
     65       passed.
     66   """
     67   if (window_size is not None and window_size_func or
     68       not (window_size is not None or window_size_func)):
     69     raise ValueError("Must pass either window_size or window_size_func.")
     70 
     71   if window_size is not None:
     72 
     73     def constant_window_func(unused_key):
     74       return ops.convert_to_tensor(window_size, dtype=dtypes.int64)
     75 
     76     window_size_func = constant_window_func
     77 
     78   assert window_size_func is not None
     79 
     80   def _apply_fn(dataset):
     81     """Function from `Dataset` to `Dataset` that applies the transformation."""
     82     return GroupByWindowDataset(dataset, key_func, reduce_func,
     83                                 window_size_func)
     84 
     85   return _apply_fn
     86 
     87 
     88 class _VariantDataset(dataset_ops.Dataset):
     89   """A Dataset wrapper for a tf.variant-typed function argument."""
     90 
     91   def __init__(self, dataset_variant, output_types, output_shapes,
     92                output_classes):
     93     super(_VariantDataset, self).__init__()
     94     self._dataset_variant = dataset_variant
     95     self._output_types = output_types
     96     self._output_shapes = output_shapes
     97     self._output_classes = output_classes
     98 
     99   def _as_variant_tensor(self):
    100     return self._dataset_variant
    101 
    102   @property
    103   def output_classes(self):
    104     return self._output_classes
    105 
    106   @property
    107   def output_shapes(self):
    108     return self._output_shapes
    109 
    110   @property
    111   def output_types(self):
    112     return self._output_types
    113 
    114 
    115 class GroupByWindowDataset(dataset_ops.Dataset):
    116   """A `Dataset` that groups its input and performs a windowed reduction."""
    117 
    118   def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    119     """See `group_by_window()` for details."""
    120     super(GroupByWindowDataset, self).__init__()
    121 
    122     self._input_dataset = input_dataset
    123 
    124     self._make_key_func(key_func, input_dataset)
    125     self._make_reduce_func(reduce_func, input_dataset)
    126     self._make_window_size_func(window_size_func)
    127 
    128   def _make_window_size_func(self, window_size_func):
    129     """Make wrapping Defun for window_size_func."""
    130 
    131     @function.Defun(dtypes.int64)
    132     def tf_window_size_func(key):
    133       key.set_shape([])
    134       window_size = ops.convert_to_tensor(
    135           window_size_func(key), dtype=dtypes.int64)
    136       if window_size.dtype != dtypes.int64:
    137         raise ValueError(
    138             "`window_size_func` must return a single tf.int64 tensor.")
    139       return window_size
    140 
    141     self._window_size_func = tf_window_size_func
    142     self._window_size_func.add_to_graph(ops.get_default_graph())
    143 
    144   def _make_key_func(self, key_func, input_dataset):
    145     """Make wrapping Defun for key_func."""
    146 
    147     @function.Defun(*nest.flatten(
    148         sparse.as_dense_types(input_dataset.output_types,
    149                               input_dataset.output_classes)))
    150     def tf_key_func(*args):
    151       """A wrapper for Defun that facilitates shape inference."""
    152       # Pass in shape information from the input_dataset.
    153       dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
    154                                             input_dataset.output_classes)
    155       for arg, shape in zip(args, nest.flatten(dense_shapes)):
    156         arg.set_shape(shape)
    157 
    158       nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
    159       nested_args = sparse.deserialize_sparse_tensors(
    160           nested_args, input_dataset.output_types, input_dataset.output_shapes,
    161           input_dataset.output_classes)
    162       # pylint: disable=protected-access
    163       if dataset_ops._should_unpack_args(nested_args):
    164         ret = key_func(*nested_args)
    165       # pylint: enable=protected-access
    166       else:
    167         ret = key_func(nested_args)
    168       ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
    169       if ret.dtype != dtypes.int64:
    170         raise ValueError("`key_func` must return a single tf.int64 tensor.")
    171       return ret
    172 
    173     self._key_func = tf_key_func
    174     self._key_func.add_to_graph(ops.get_default_graph())
    175 
    176   def _make_reduce_func(self, reduce_func, input_dataset):
    177     """Make wrapping Defun for reduce_func."""
    178 
    179     @function.Defun(dtypes.int64, dtypes.variant)
    180     def tf_reduce_func(key, window_dataset_variant):
    181       """A wrapper for Defun that facilitates shape inference."""
    182       key.set_shape([])
    183       window_dataset = _VariantDataset(
    184           window_dataset_variant, input_dataset.output_types,
    185           input_dataset.output_shapes, input_dataset.output_classes)
    186       if not isinstance(window_dataset, dataset_ops.Dataset):
    187         raise TypeError("`window_dataset` must return a `Dataset` object.")
    188       output_dataset = reduce_func(key, window_dataset)
    189       if not isinstance(output_dataset, dataset_ops.Dataset):
    190         raise TypeError("`reduce_func` must return a `Dataset` object.")
    191       self._output_classes = output_dataset.output_classes
    192       self._output_types = output_dataset.output_types
    193       self._output_shapes = output_dataset.output_shapes
    194       return output_dataset._as_variant_tensor()  # pylint: disable=protected-access
    195 
    196     self._reduce_func = tf_reduce_func
    197     self._reduce_func.add_to_graph(ops.get_default_graph())
    198 
    199   @property
    200   def output_classes(self):
    201     return self._output_classes
    202 
    203   @property
    204   def output_shapes(self):
    205     return self._output_shapes
    206 
    207   @property
    208   def output_types(self):
    209     return self._output_types
    210 
    211   def _as_variant_tensor(self):
    212     return gen_dataset_ops.group_by_window_dataset(
    213         self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
    214         self._key_func.captured_inputs,
    215         self._reduce_func.captured_inputs,
    216         self._window_size_func.captured_inputs,
    217         key_func=self._key_func,
    218         reduce_func=self._reduce_func,
    219         window_size_func=self._window_size_func,
    220         output_types=nest.flatten(
    221             sparse.as_dense_types(self.output_types, self.output_classes)),
    222         output_shapes=nest.flatten(
    223             sparse.as_dense_shapes(self.output_shapes, self.output_classes)))
    224