Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Seq2seq layer operations for use in neural networks."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import abc
     22 import six
     23 
     24 from tensorflow.python.framework import constant_op
     25 from tensorflow.python.framework import dtypes
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.framework import tensor_util
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import control_flow_ops
     31 from tensorflow.python.ops import math_ops
     32 from tensorflow.python.ops import rnn
     33 from tensorflow.python.ops import tensor_array_ops
     34 from tensorflow.python.ops import variable_scope
     35 from tensorflow.python.util import nest
     36 
     37 
# Names exported as this module's public API.
__all__ = ["Decoder", "dynamic_decode"]


# Module-local alias for the private helper in `rnn`; `dynamic_decode` applies
# it to its stacked outputs when `output_time_major` is False.
_transpose_batch_time = rnn._transpose_batch_time  # pylint: disable=protected-access
     42 
     43 
@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """An RNN Decoder abstract interface object.

  Concepts used by this interface:
  - `inputs`: (structure of) tensors and TensorArrays that is passed as input to
    the RNNCell composing the decoder, at each time step.
  - `state`: (structure of) tensors and TensorArrays that is passed to the
    RNNCell instance as the state.
  - `finished`: boolean tensor telling whether each sequence in the batch is
    finished.
  - `outputs`: Instance of BasicDecoderOutput. Result of the decoding, at each
    time step.

  Subclasses must implement the abstract methods `initialize` and `step`.
  The `batch_size`, `output_size` and `output_dtype` properties and the
  `finalize` method raise `NotImplementedError` unless overridden.
  `dynamic_decode` reads all three properties, so a decoder used with it has
  to provide them; `finalize` is optional there, because `dynamic_decode`
  catches the `NotImplementedError` it raises.
  """

  @property
  def batch_size(self):
    """The batch size of input values.

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  @property
  def output_size(self):
    """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  @property
  def output_dtype(self):
    """A (possibly nested tuple of...) dtype[s].

    Raises:
      NotImplementedError: always, in this base class.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def initialize(self, name=None):
    """Called before any decoding iterations.

    This method must compute initial input values and initial state.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, initial_inputs, initial_state)`: initial values of
      'finished' flags, inputs and state.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def step(self, time, inputs, state, name=None):
    """Called per step of decoding (but only once for dynamic decoding).

    Args:
      time: Scalar `int32` tensor. Current step number.
      inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
        step.
      state: RNNCell state (possibly nested tuple of) tensor[s] from previous
        time step.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`: `outputs` is an object
      containing the decoder output, `next_state` is a (structure of) state
      tensors and TensorArrays, `next_inputs` is the tensor that should be used
      as input for the next step, `finished` is a boolean tensor telling whether
      the sequence is complete, for each sequence in the batch.
    """
    raise NotImplementedError

  def finalize(self, outputs, final_state, sequence_lengths):
    """Optional hook called once after the decoding loop has completed.

    `dynamic_decode` calls this with the stacked per-step outputs (before any
    batch-major transpose is applied) and replaces its outputs and final state
    with the two values returned here.

    Args:
      outputs: (structure of) tensors holding the decoder outputs of all time
        steps, stacked along the leading (time) axis.
      final_state: (structure of) state tensors and TensorArrays after the
        last decoding step.
      sequence_lengths: `int32` tensor with, for each batch entry, the step
        count at which it was marked finished.

    Returns:
      `(final_outputs, final_state)`: the finalized outputs and state.

    Raises:
      NotImplementedError: always, in this base class.  `dynamic_decode`
        catches it and keeps the un-finalized outputs and state.
    """
    raise NotImplementedError

  @property
  def tracks_own_finished(self):
    """Describes whether the Decoder keeps track of finished states.

    Most decoders will emit a true/false `finished` value independently
    at each time step.  In this case, the `dynamic_decode` function keeps track
    of which batch entries are already finished, and performs a logical OR to
    insert new batches to the finished set.

    Some decoders, however, shuffle batches / beams between time steps and
    `dynamic_decode` will mix up the finished state across these entries because
    it does not track the reshuffle across time steps.  In this case, it is
    up to the decoder to declare that it will keep track of its own finished
    state by setting this property to `True`.

    Returns:
      Python bool.  This base implementation returns `False`.
    """
    return False
    132 
    133 
    134 def _create_zero_outputs(size, dtype, batch_size):
    135   """Create a zero outputs Tensor structure."""
    136   def _t(s):
    137     return (s if isinstance(s, ops.Tensor) else constant_op.constant(
    138         tensor_shape.TensorShape(s).as_list(),
    139         dtype=dtypes.int32,
    140         name="zero_suffix_shape"))
    141 
    142   def _create(s, d):
    143     return array_ops.zeros(
    144         array_ops.concat(
    145             ([batch_size], _t(s)), axis=0), dtype=d)
    146 
    147   return nest.map_structure(_create, size, dtype)
    148 
    149 
def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once on the Decoder object, then drives its step() from
  the body of a `while_loop` that runs until every batch entry is marked
  finished (or `maximum_iterations` is reached).

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.  `final_outputs`
    and `final_state` are the stacked per-step outputs and the last loop
    state, passed through `decoder.finalize` when the decoder implements it.
    `final_sequence_lengths` is the `int32` tensor of per-entry step counts
    accumulated by the loop; it is not altered by `finalize`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # Properly cache variable values inside the while_loop
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      # Only a statically known rank of 0 passes this check.
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    initial_finished, initial_inputs, initial_state = decoder.initialize()

    # Zeros shaped like one step of decoder output; only consumed in `body`
    # when `impute_finished` is set.
    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    # With a non-positive iteration limit every entry starts out finished, so
    # the loop body never runs.
    if maximum_iterations is not None:
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      """Element shape for an output TensorArray: batch dim + `from_shape`."""
      # Anything that is not a TensorShape (e.g. a plain integer size) yields
      # a fully unknown element shape.
      if not isinstance(from_shape, tensor_shape.TensorShape):
        return tensor_shape.TensorShape(None)
      else:
        # Use the static value of the batch size where one is available.
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    def _create_ta(s, d):
      """Create an empty, dynamically sized TensorArray for one output leaf."""
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0,
          dynamic_size=True,
          element_shape=_shape(decoder.batch_size, s))

    # One TensorArray per leaf of the decoder's output structure.
    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      """Keep looping while at least one batch entry is not finished."""
      return math_ops.logical_not(math_ops.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      # Unless the decoder tracks its own finished flags, an entry that was
      # finished before this step stays finished.
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)
      # Force every entry to finished once the iteration limit is reached.
      if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
      # Record `time + 1` as the length of entries that become finished in
      # this step; all other entries keep their current value.
      next_sequence_lengths = array_ops.where(
          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      # The decoder must keep the loop variables' nested structures unchanged.
      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish
      # (selection uses the flags from before this step, so the output of the
      # step in which an entry finishes is still emitted).
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish
      def _maybe_copy_state(new, cur):
        """Keep `cur` for entries that were already finished, else `new`."""
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          # Propagate the static shape of the current state to the new one.
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      # Append this step's (possibly zeroed) outputs at index `time`.
      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=[
            initial_time, initial_outputs_ta, initial_state, initial_inputs,
            initial_finished, initial_sequence_lengths,
        ],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # Indices follow the order of `loop_vars` above.
    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    # Stack each TensorArray into a single tensor with time as leading axis.
    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    # `finalize` is optional: decoders that do not implement it raise
    # NotImplementedError, in which case outputs and state are left as-is.
    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths
    327