# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Scan dataset transformation."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.data.util import nest
from tensorflow.python.data.util import sparse
from tensorflow.python.framework import function
from tensorflow.python.framework import ops
from tensorflow.python.ops import gen_dataset_ops


class _ScanDataset(dataset_ops.Dataset):
  """A dataset that scans a function across its input."""

  def __init__(self, input_dataset, initial_state, scan_func):
    """See `scan()` for details."""
    super(_ScanDataset, self).__init__()
    self._input_dataset = input_dataset

    with ops.name_scope("initial_state"):
      self._initial_state = nest.pack_sequence_as(initial_state, [
          ops.convert_to_tensor(t, name="component_%d" % i)
          for i, t in enumerate(nest.flatten(initial_state))
      ])

    # Compute initial values for the state shapes and types based on
    # the initial state. These will be refined by running
    # `tf_scan_func` one or more times below.
    # TODO(b/68937811): Allow the initial state to be a tf.SparseTensor.
    self._state_shapes = nest.pack_sequence_as(
        self._initial_state,
        [t.shape for t in nest.flatten(self._initial_state)])
    self._state_types = nest.pack_sequence_as(
        self._initial_state,
        [t.dtype for t in nest.flatten(self._initial_state)])
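    # For example, an initial state of `tf.constant(0)` yields a state shape of
    # `[]` (a scalar) and a state type of `tf.int32`.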

    # Will be populated by calling `tf_scan_func`.
    self._output_classes = None
    self._output_shapes = None
    self._output_types = None

    # Iteratively rerun the scan function until reaching a fixed point on
    # `self._state_shapes`.
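    # For example, if the initial state has shape [1] but `scan_func` returns a
    # new state with shape [2], the state shape is weakened to [None] and
    # `tf_scan_func` is built again with the weaker shape.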
    need_to_rerun = True
    while need_to_rerun:

      flat_state_shapes = nest.flatten(self._state_shapes)
      flat_state_types = nest.flatten(self._state_types)

      # Create a list in which `tf_scan_func` will store the new state shapes.
      flat_new_state_shapes = []

      @function.Defun(*(flat_state_types + nest.flatten(
          sparse.as_dense_types(input_dataset.output_types,
                                input_dataset.output_classes))))
      def tf_scan_func(*args):
        """A wrapper for Defun that facilitates shape inference."""
        # Pass in shape information from the state and input_dataset.
        # TODO(b/69424092): Check that neither inputs nor outputs are sparse.
        dense_shapes = sparse.as_dense_shapes(input_dataset.output_shapes,
                                              input_dataset.output_classes)
        for arg, shape in zip(args,
                              flat_state_shapes + nest.flatten(dense_shapes)):
          arg.set_shape(shape)

        pivot = len(flat_state_shapes)
        old_state = nest.pack_sequence_as(self._initial_state, args[:pivot])
        input_value = nest.pack_sequence_as(input_dataset.output_types,
                                            args[pivot:])
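        # At this point `old_state` and `input_value` have the same nested
        # structure as `initial_state` and the input element, respectively.
        # For example, with a two-component state and a single-component input
        # element, `args` is `(state_0, state_1, input_0)` and `pivot` is 2.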

        ret = scan_func(old_state, input_value)
        if not isinstance(ret, collections.Sequence) or len(ret) != 2:
          raise TypeError("The scan function must return a pair comprising the "
                          "new state and the output value.")
        new_state, output_value = ret

        flat_new_state = [
            ops.convert_to_tensor(t) for t in nest.flatten(new_state)
        ]
        flat_output_value = [
            ops.convert_to_tensor(t) for t in nest.flatten(output_value)
        ]

        # Extract shape information from the returned values.
        flat_new_state_shapes.extend([t.shape for t in flat_new_state])
        self._output_shapes = nest.pack_sequence_as(
            output_value, [t.shape for t in flat_output_value])

        # Extract and validate type information from the returned values.
        for t, dtype in zip(flat_new_state, flat_state_types):
          if t.dtype != dtype:
            raise TypeError(
                "The element types for the new state must match the initial "
                "state. Expected %s; got %s." %
                (self._state_types, nest.pack_sequence_as(
                    self._state_types, [t.dtype for t in flat_new_state])))
        self._output_classes = nest.pack_sequence_as(
            output_value, [ops.Tensor for _ in flat_output_value])
        self._output_types = nest.pack_sequence_as(
            output_value, [t.dtype for t in flat_output_value])

        return flat_new_state + flat_output_value

      # Use the private method that will execute `tf_scan_func` but delay
      # adding it to the graph in case we need to rerun the function.
      tf_scan_func._create_definition_if_needed()  # pylint: disable=protected-access

      weakened_state_shapes = [
          original.most_specific_compatible_shape(new)
          for original, new in zip(flat_state_shapes, flat_new_state_shapes)
      ]
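      # For example, the most specific shape compatible with [1] and [2] is
      # [None], while the most specific shape compatible with [2] and [2] is
      # [2].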

      need_to_rerun = False
      for original_shape, weakened_shape in zip(flat_state_shapes,
                                                weakened_state_shapes):
        if original_shape.ndims is not None and (
            weakened_shape.ndims is None or
            original_shape.as_list() != weakened_shape.as_list()):
          need_to_rerun = True
          break

      if need_to_rerun:
        # NOTE(mrry): `self._output_shapes` will be overwritten when we rerun
        # `tf_scan_func`.
        self._state_shapes = nest.pack_sequence_as(self._state_shapes,
                                                   weakened_state_shapes)

    self._scan_func = tf_scan_func

  def _as_variant_tensor(self):
    input_t = self._input_dataset._as_variant_tensor()  # pylint: disable=protected-access
    return gen_dataset_ops.scan_dataset(
        input_t,
        nest.flatten(self._initial_state),
        self._scan_func.captured_inputs,
        f=self._scan_func,
        output_types=nest.flatten(
            sparse.as_dense_types(self.output_types, self.output_classes)),
        output_shapes=nest.flatten(
            sparse.as_dense_shapes(self.output_shapes, self.output_classes)))

  @property
  def output_classes(self):
    return self._output_classes

  @property
  def output_shapes(self):
    return self._output_shapes

  @property
  def output_types(self):
    return self._output_types


def scan(initial_state, scan_func):
  """A transformation that scans a function across an input dataset.

  This transformation is a stateful relative of @{tf.data.Dataset.map}.
  In addition to mapping `scan_func` across the elements of the input dataset,
  `scan()` accumulates one or more state tensors, whose initial values are
  `initial_state`.

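  For example, a running total over a dataset of integers might be expressed as
  follows (a minimal sketch, assuming this transformation is exposed publicly as
  `tf.contrib.data.scan`):

  ```python
  dataset = tf.data.Dataset.range(10)

  # The state is the running sum; each output element is the sum up to and
  # including the current input element.
  dataset = dataset.apply(tf.contrib.data.scan(
      initial_state=tf.constant(0, dtype=tf.int64),
      scan_func=lambda state, x: (state + x, state + x)))
  # Yields: 0, 1, 3, 6, 10, 15, 21, 28, 36, 45
  ```
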
  Args:
    initial_state: A nested structure of tensors, representing the initial state
      of the accumulator.
    scan_func: A function that maps `(old_state, input_element)` to
      `(new_state, output_element)`. It must take two arguments and return a
      pair of nested structures of tensors. The `new_state` must match the
      structure of `initial_state`.

  Returns:
    A `Dataset` transformation function, which can be passed to
    @{tf.data.Dataset.apply}.
  """
  def _apply_fn(dataset):
    return _ScanDataset(dataset, initial_state, scan_func)

  return _apply_fn