Home | History | Annotate | Download | only in ops
      1 # Copyright 2016 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A class of Decoders that may sample to generate the next input.
     16 """
     17 
     18 from __future__ import absolute_import
     19 from __future__ import division
     20 from __future__ import print_function
     21 
     22 import collections
     23 
     24 from tensorflow.contrib.seq2seq.python.ops import decoder
     25 from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
     26 from tensorflow.python.framework import ops
     27 from tensorflow.python.framework import tensor_shape
     28 from tensorflow.python.layers import base as layers_base
     29 from tensorflow.python.ops import rnn_cell_impl
     30 from tensorflow.python.util import nest
     31 
     32 
     33 __all__ = [
     34     "BasicDecoderOutput",
     35     "BasicDecoder",
     36 ]
     37 
     38 
     39 class BasicDecoderOutput(
     40     collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
     41   pass
     42 
     43 
     44 class BasicDecoder(decoder.Decoder):
     45   """Basic sampling decoder."""
     46 
     47   def __init__(self, cell, helper, initial_state, output_layer=None):
     48     """Initialize BasicDecoder.
     49 
     50     Args:
     51       cell: An `RNNCell` instance.
     52       helper: A `Helper` instance.
     53       initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
     54         The initial state of the RNNCell.
     55       output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
     56         `tf.layers.Dense`. Optional layer to apply to the RNN output prior
     57         to storing the result or sampling.
     58 
     59     Raises:
     60       TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
     61     """
     62     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
     63       raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
     64     if not isinstance(helper, helper_py.Helper):
     65       raise TypeError("helper must be a Helper, received: %s" % type(helper))
     66     if (output_layer is not None
     67         and not isinstance(output_layer, layers_base.Layer)):
     68       raise TypeError(
     69           "output_layer must be a Layer, received: %s" % type(output_layer))
     70     self._cell = cell
     71     self._helper = helper
     72     self._initial_state = initial_state
     73     self._output_layer = output_layer
     74 
     75   @property
     76   def batch_size(self):
     77     return self._helper.batch_size
     78 
     79   def _rnn_output_size(self):
     80     size = self._cell.output_size
     81     if self._output_layer is None:
     82       return size
     83     else:
     84       # To use layer's compute_output_shape, we need to convert the
     85       # RNNCell's output_size entries into shapes with an unknown
     86       # batch size.  We then pass this through the layer's
     87       # compute_output_shape and read off all but the first (batch)
     88       # dimensions to get the output size of the rnn with the layer
     89       # applied to the top.
     90       output_shape_with_unknown_batch = nest.map_structure(
     91           lambda s: tensor_shape.TensorShape([None]).concatenate(s),
     92           size)
     93       layer_output_shape = self._output_layer.compute_output_shape(
     94           output_shape_with_unknown_batch)
     95       return nest.map_structure(lambda s: s[1:], layer_output_shape)
     96 
     97   @property
     98   def output_size(self):
     99     # Return the cell output and the id
    100     return BasicDecoderOutput(
    101         rnn_output=self._rnn_output_size(),
    102         sample_id=self._helper.sample_ids_shape)
    103 
    104   @property
    105   def output_dtype(self):
    106     # Assume the dtype of the cell is the output_size structure
    107     # containing the input_state's first component's dtype.
    108     # Return that structure and the sample_ids_dtype from the helper.
    109     dtype = nest.flatten(self._initial_state)[0].dtype
    110     return BasicDecoderOutput(
    111         nest.map_structure(lambda _: dtype, self._rnn_output_size()),
    112         self._helper.sample_ids_dtype)
    113 
    114   def initialize(self, name=None):
    115     """Initialize the decoder.
    116 
    117     Args:
    118       name: Name scope for any created operations.
    119 
    120     Returns:
    121       `(finished, first_inputs, initial_state)`.
    122     """
    123     return self._helper.initialize() + (self._initial_state,)
    124 
    125   def step(self, time, inputs, state, name=None):
    126     """Perform a decoding step.
    127 
    128     Args:
    129       time: scalar `int32` tensor.
    130       inputs: A (structure of) input tensors.
    131       state: A (structure of) state tensors and TensorArrays.
    132       name: Name scope for any created operations.
    133 
    134     Returns:
    135       `(outputs, next_state, next_inputs, finished)`.
    136     """
    137     with ops.name_scope(name, "BasicDecoderStep", (time, inputs, state)):
    138       cell_outputs, cell_state = self._cell(inputs, state)
    139       if self._output_layer is not None:
    140         cell_outputs = self._output_layer(cell_outputs)
    141       sample_ids = self._helper.sample(
    142           time=time, outputs=cell_outputs, state=cell_state)
    143       (finished, next_inputs, next_state) = self._helper.next_inputs(
    144           time=time,
    145           outputs=cell_outputs,
    146           state=cell_state,
    147           sample_ids=sample_ids)
    148     outputs = BasicDecoderOutput(cell_outputs, sample_ids)
    149     return (outputs, next_state, next_inputs, finished)
    150