# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Grouping dataset transformations."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


def group_by_window(key_func,
                    reduce_func,
                    window_size=None,
                    window_size_func=None):
  """A transformation that groups windows of elements by key and reduces them.

  This transformation maps each consecutive element in a dataset to a key
  using `key_func` and groups the elements by key. It then applies
  `reduce_func` to at most `window_size_func(key)` elements matching the same
  key. All except the final window for each key will contain
  `window_size_func(key)` elements; the final window may be smaller.

  You may provide either a constant `window_size` or a window size determined by
  the key through `window_size_func`.

  Args:
    key_func: A function mapping a nested structure of tensors
      (having shapes and types defined by `self.output_shapes` and
      `self.output_types`) to a scalar `tf.int64` tensor.
    reduce_func: A function mapping a key and a dataset of up to `window_size`
      consecutive elements matching that key to another dataset.
    window_size: A `tf.int64` scalar `tf.Tensor`, representing the number of
      consecutive elements matching the same key to combine in a single
      batch, which will be passed to `reduce_func`. Mutually exclusive with
      `window_size_func`.
    window_size_func: A function mapping a key to a `tf.int64` scalar
      `tf.Tensor`, representing the number of consecutive elements matching
      the same key to combine in a single batch, which will be passed to
      `reduce_func`. Mutually exclusive with `window_size`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.

  Raises:
    ValueError: if neither or both of {`window_size`, `window_size_func`} are
      passed.
  """
  # Exactly one of the two window-size arguments must be supplied.
  # `window_size` is tested against None (0 is a value, not "absent"), while
  # `window_size_func` is tested for truthiness.
  has_window_size = window_size is not None
  has_window_size_func = bool(window_size_func)
  if has_window_size == has_window_size_func:
    raise ValueError("Must pass either window_size or window_size_func.")

  if has_window_size:

    def constant_window_func(unused_key):
      """Returns the constant `window_size` regardless of the key."""
      return ops.convert_to_tensor(window_size, dtype=dtypes.int64)

    window_size_func = constant_window_func

  # Internal invariant: the validation above guarantees a usable function.
  assert window_size_func is not None

  def _apply_fn(dataset):
    """Function from `Dataset` to `Dataset` that applies the transformation."""
    return GroupByWindowDataset(dataset, key_func, reduce_func,
                                window_size_func)

  return _apply_fn


class _VariantDataset(dataset_ops.Dataset):
  """A Dataset wrapper for a tf.variant-typed function argument."""

  def __init__(self, dataset_variant, output_types, output_shapes,
               output_classes):
    """Stores the variant tensor together with its declared structure.

    Args:
      dataset_variant: The value returned unchanged by `_as_variant_tensor()`.
      output_types: The value reported by the `output_types` property.
      output_shapes: The value reported by the `output_shapes` property.
      output_classes: The value reported by the `output_classes` property.
    """
    super(_VariantDataset, self).__init__()
    self._dataset_variant = dataset_variant
    self._output_types = output_types
    self._output_shapes = output_shapes
    self._output_classes = output_classes

  def _as_variant_tensor(self):
    """Returns the wrapped variant tensor as-is."""
    return self._dataset_variant

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


class GroupByWindowDataset(dataset_ops.Dataset):
  """A `Dataset` that groups its input and performs a windowed reduction."""

  def __init__(self, input_dataset, key_func, reduce_func, window_size_func):
    """See `group_by_window()` for details."""
    super(GroupByWindowDataset, self).__init__()

    self._input_dataset = input_dataset

    # `_make_reduce_func` assigns `_output_classes`, `_output_types` and
    # `_output_shapes` from inside the wrapped function body; the `output_*`
    # properties below read those attributes.
    self._make_key_func(key_func, input_dataset)
    self._make_reduce_func(reduce_func, input_dataset)
    self._make_window_size_func(window_size_func)

  def _make_window_size_func(self, window_size_func):
    """Make wrapping Defun for window_size_func.

    Args:
      window_size_func: A function taking the key tensor; its result is
        converted to a `tf.int64` tensor.
    """

    @function.Defun(dtypes.int64)
    def tf_window_size_func(key):
      """Wraps `window_size_func`, marking the key as a scalar."""
      key.set_shape([])
      window_size = ops.convert_to_tensor(
          window_size_func(key), dtype=dtypes.int64)
      if window_size.dtype != dtypes.int64:
        raise ValueError(
            "`window_size_func` must return a single tf.int64 tensor.")
      return window_size

    self._window_size_func = tf_window_size_func
    self._window_size_func.add_to_graph(ops.get_default_graph())

  def _make_key_func(self, key_func, input_dataset):
    """Make wrapping Defun for key_func.

    Args:
      key_func: A function taking the (possibly unpacked) input element; its
        result is converted to a `tf.int64` tensor.
      input_dataset: The dataset whose `output_types`, `output_shapes` and
        `output_classes` describe the arguments of `key_func`.
    """

    @function.Defun(*nest.flatten(
        sparse.as_dense_types(input_dataset.output_types,
                              input_dataset.output_classes)))
    def tf_key_func(*args):
      """A wrapper for Defun that facilitates shape inference."""
      # Pass in shape information from the input_dataset.
      dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                            input_dataset.output_classes)
      for arg, shape in zip(args, nest.flatten(dense_shapes)):
        arg.set_shape(shape)

      # Rebuild the nested element structure from the flat argument list.
      nested_args = nest.pack_sequence_as(input_dataset.output_types, args)
      nested_args = sparse.deserialize_sparse_tensors(
          nested_args, input_dataset.output_types, input_dataset.output_shapes,
          input_dataset.output_classes)
      if dataset_ops._should_unpack_args(nested_args):  # pylint: disable=protected-access
        ret = key_func(*nested_args)
      else:
        ret = key_func(nested_args)
      ret = ops.convert_to_tensor(ret, dtype=dtypes.int64)
      if ret.dtype != dtypes.int64:
        raise ValueError("`key_func` must return a single tf.int64 tensor.")
      return ret

    self._key_func = tf_key_func
    self._key_func.add_to_graph(ops.get_default_graph())

  def _make_reduce_func(self, reduce_func, input_dataset):
    """Make wrapping Defun for reduce_func.

    Args:
      reduce_func: A function taking the key tensor and a window `Dataset`,
        and returning a `Dataset`.
      input_dataset: The dataset whose `output_types`, `output_shapes` and
        `output_classes` are used to describe each window dataset.

    Raises:
      TypeError: if `reduce_func` does not return a `Dataset` object.
    """

    @function.Defun(dtypes.int64, dtypes.variant)
    def tf_reduce_func(key, window_dataset_variant):
      """A wrapper for Defun that facilitates shape inference."""
      key.set_shape([])
      # The window arrives as a variant tensor; wrap it so `reduce_func`
      # receives a `Dataset` with the same structure as the input dataset.
      window_dataset = _VariantDataset(
          window_dataset_variant, input_dataset.output_types,
          input_dataset.output_shapes, input_dataset.output_classes)
      output_dataset = reduce_func(key, window_dataset)
      if not isinstance(output_dataset, dataset_ops.Dataset):
        raise TypeError("`reduce_func` must return a `Dataset` object.")
      # Record the structure of the reduced dataset on the enclosing
      # `GroupByWindowDataset`; it becomes this dataset's output structure.
      self._output_classes = output_dataset.output_classes
      self._output_types = output_dataset.output_types
      self._output_shapes = output_dataset.output_shapes
      return output_dataset._as_variant_tensor()  # pylint: disable=protected-access

    self._reduce_func = tf_reduce_func
    self._reduce_func.add_to_graph(ops.get_default_graph())

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types

  def _as_variant_tensor(self):
    """Builds the `group_by_window_dataset` op for this dataset."""
    return gen_dataset_ops.group_by_window_dataset(
        self._input_dataset._as_variant_tensor(),  # pylint: disable=protected-access
        self._key_func.captured_inputs,
        self._reduce_func.captured_inputs,
        self._window_size_func.captured_inputs,
        key_func=self._key_func,
        reduce_func=self._reduce_func,
        window_size_func=self._window_size_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))