1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Scan dataset transformation.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 import collections 21 22 from tensorflow.python.data.ops import dataset_ops 23 from tensorflow.python.data.util import nest 24 from tensorflow.python.data.util import sparse 25 from tensorflow.python.framework import function 26 from tensorflow.python.framework import ops 27 from tensorflow.python.ops import gen_dataset_ops 28 29 30 class _ScanDataset(dataset_ops.Dataset): 31 """A dataset that scans a function across its input.""" 32 33 def __init__(self, input_dataset, initial_state, scan_func): 34 """See `scan()` for details.""" 35 super(_ScanDataset, self).__init__() 36 self._input_dataset = input_dataset 37 38 with ops.name_scope("initial_state"): 39 self._initial_state = nest.pack_sequence_as(initial_state, [ 40 ops.convert_to_tensor(t, name="component_%d" % i) 41 for i, t in enumerate(nest.flatten(initial_state)) 42 ]) 43 44 # Compute initial values for the state shapes and types based on 45 # the initial state. These will be refined by running 46 # `tf_scan_func` one or more times below. 47 # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor. 48 self._state_shapes = nest.pack_sequence_as( 49 self._initial_state, 50 [t.shape for t in nest.flatten(self._initial_state)]) 51 self._state_types = nest.pack_sequence_as( 52 self._initial_state, 53 [t.dtype for t in nest.flatten(self._initial_state)]) 54 55 # Will be populated by calling `tf_scan_func`. 56 self._output_classes = None 57 self._output_shapes = None 58 self._output_types = None 59 60 # Iteratively rerun the scan function until reaching a fixed pont on 61 # `self._state_shapes`. 62 need_to_rerun = True 63 while need_to_rerun: 64 65 flat_state_shapes = nest.flatten(self._state_shapes) 66 flat_state_types = nest.flatten(self._state_types) 67 68 # Create a list in which `tf_scan_func` will store the s 69 flat_new_state_shapes = [] 70 71 @function.Defun(*(flat_state_types + nest.flatten( 72 sparse.as_dense_types(input_dataset.output_types, 73 input_dataset.output_classes)))) 74 def tf_scan_func(*args): 75 """A wrapper for Defun that facilitates shape inference.""" 76 # Pass in shape information from the state and input_dataset. 77 # TODO(b/69424092): Check that neither inputs nor outputs are sparse. 78 dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes, 79 input_dataset.output_classes) 80 for arg, shape in zip(args, 81 flat_state_shapes + nest.flatten(dense_shapes)): 82 arg.set_shape(shape) 83 84 pivot = len(flat_state_shapes) 85 old_state = nest.pack_sequence_as(self._initial_state, args[:pivot]) 86 input_value = nest.pack_sequence_as(input_dataset.output_types, 87 args[pivot:]) 88 89 ret = scan_func(old_state, input_value) 90 if not isinstance(ret, collections.Sequence) or len(ret) != 2: 91 raise TypeError("The scan function must return a pair comprising the " 92 "new state and the output value.") 93 new_state, output_value = ret 94 95 flat_new_state = [ 96 ops.convert_to_tensor(t) for t in nest.flatten(new_state) 97 ] 98 flat_output_value = [ 99 ops.convert_to_tensor(t) for t in nest.flatten(output_value) 100 ] 101 102 # Extract shape information from the returned values. 103 flat_new_state_shapes.extend([t.shape for t in flat_new_state]) 104 self._output_shapes = nest.pack_sequence_as( 105 output_value, [t.shape for t in flat_output_value]) 106 107 # Extract and validate type information from the returned values. 108 for t, dtype in zip(flat_new_state, flat_state_types): 109 if t.dtype != dtype: 110 raise TypeError( 111 "The element types for the new state must match the initial " 112 "state. Expected %s; got %s." % 113 (self._state_types, nest.pack_sequence_as( 114 self._state_types, [t.dtype for t in flat_new_state]))) 115 self._output_classes = nest.pack_sequence_as( 116 output_value, [ops.Tensor for _ in flat_output_value]) 117 self._output_types = nest.pack_sequence_as( 118 output_value, [t.dtype for t in flat_output_value]) 119 120 return flat_new_state + flat_output_value 121 122 # Use the private method that will execute `tf_scan_func` but delay 123 # adding it to the graph in case we need to rerun the function. 124 tf_scan_func._create_definition_if_needed() # pylint: disable=protected-access 125 126 weakened_state_shapes = [ 127 original.most_specific_compatible_shape(new) 128 for original, new in zip(flat_state_shapes, flat_new_state_shapes) 129 ] 130 131 need_to_rerun = False 132 for original_shape, weakened_shape in zip(flat_state_shapes, 133 weakened_state_shapes): 134 if original_shape.ndims is not None and ( 135 weakened_shape.ndims is None or 136 original_shape.as_list() != weakened_shape.as_list()): 137 need_to_rerun = True 138 break 139 140 if need_to_rerun: 141 # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun 142 # `tf_scan_func`. 143 self._state_shapes = nest.pack_sequence_as(self._state_shapes, 144 weakened_state_shapes) 145 146 self._scan_func = tf_scan_func 147 148 def _as_variant_tensor(self): 149 input_t = self._input_dataset._as_variant_tensor() # pylint: disable=protected-access 150 return gen_dataset_ops.scan_dataset( 151 input_t, 152 nest.flatten(self._initial_state), 153 self._scan_func.captured_inputs, 154 f=self._scan_func, 155 output_types=nest.flatten( 156 sparse.as_dense_types(self.output_types, self.output_classes)), 157 output_shapes=nest.flatten( 158 sparse.as_dense_shapes(self.output_shapes, self.output_classes))) 159 160 @property 161 def output_classes(self): 162 return self._output_classes 163 164 @property 165 def output_shapes(self): 166 return self._output_shapes 167 168 @property 169 def output_types(self): 170 return self._output_types 171 172 173 def scan(initial_state, scan_func): 174 """A transformation that scans a function across an input dataset. 175 176 This transformation is a stateful relative of @{tf.data.Dataset.map}. 177 In addition to mapping `scan_func` across the elements of the input dataset, 178 `scan()` accumulates one or more state tensors, whose initial values are 179 `initial_state`. 180 181 Args: 182 initial_state: A nested structure of tensors, representing the initial state 183 of the accumulator. 184 scan_func: A function that maps `(old_state, input_element)` to 185 `(new_state, output_element). It must take two arguments and return a 186 pair of nested structures of tensors. The `new_state` must match the 187 structure of `initial_state`. 188 189 Returns: 190 A `Dataset` transformation function, which can be passed to 191 @{tf.data.Dataset.apply}. 192 """ 193 def _apply_fn(dataset): 194 return _ScanDataset(dataset, initial_state, scan_func) 195 196 return _apply_fn 197