# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A class of Decoders that may sample to generate the next input.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections

from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.contrib.seq2seq.python.ops import helper as helper_py
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.util import nest


__all__ = [
    "BasicDecoderOutput",
    "BasicDecoder",
]


class BasicDecoderOutput(
    collections.namedtuple("BasicDecoderOutput", ("rnn_output", "sample_id"))):
  """Named tuple with the two fields `rnn_output` and `sample_id`.

  `BasicDecoder` uses this structure both for the per-step outputs returned
  from `step` and for describing their sizes and dtypes.
  """


class BasicDecoder(decoder.Decoder):
  """Basic sampling decoder."""

  def __init__(self, cell, helper, initial_state, output_layer=None):
    """Build a BasicDecoder from a cell, a helper and a starting state.

    Args:
      cell: An `RNNCell` instance.
      helper: A `Helper` instance.
      initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
        The initial state of the RNNCell.
      output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
        `tf.layers.Dense`. When given, it is applied to the RNN output
        before the result is stored or sampled from.

    Raises:
      TypeError: if `cell`, `helper` or `output_layer` have an incorrect type.
    """
    # Arguments are validated in signature order, so the first offending
    # argument is the one that gets reported.
    cell_is_valid = rnn_cell_impl._like_rnncell(cell)  # pylint: disable=protected-access
    if not cell_is_valid:
      raise TypeError("cell must be an RNNCell, received: %s" % type(cell))

    helper_is_valid = isinstance(helper, helper_py.Helper)
    if not helper_is_valid:
      raise TypeError("helper must be a Helper, received: %s" % type(helper))

    # `None` is accepted and means "no projection on top of the cell".
    layer_is_valid = (
        output_layer is None or isinstance(output_layer, layers_base.Layer))
    if not layer_is_valid:
      raise TypeError(
          "output_layer must be a Layer, received: %s" % type(output_layer))

    self._cell = cell
    self._helper = helper
    self._initial_state = initial_state
    self._output_layer = output_layer

  @property
  def batch_size(self):
    """The batch size, as reported by the helper."""
    helper = self._helper
    return helper.batch_size

  def _rnn_output_size(self):
    """Return the size structure of the RNN output after the output layer.

    Without an output layer this is simply the cell's `output_size`.
    Otherwise the cell's sizes are pushed through the layer's
    `compute_output_shape` and the leading (batch) dimension is dropped.
    """
    cell_output_size = self._cell.output_size
    output_layer = self._output_layer
    if output_layer is None:
      return cell_output_size

    def _with_unknown_batch(entry):
      # `compute_output_shape` expects full shapes, so prepend an unknown
      # batch dimension to every entry of the cell's output size.
      return tensor_shape.TensorShape([None]).concatenate(entry)

    def _without_batch(shape):
      # Everything after the first (batch) dimension is the output size.
      return shape[1:]

    batched_shapes = nest.map_structure(_with_unknown_batch, cell_output_size)
    projected_shapes = output_layer.compute_output_shape(batched_shapes)
    return nest.map_structure(_without_batch, projected_shapes)

  @property
  def output_size(self):
    """A `BasicDecoderOutput` of the RNN output size and the sample id shape."""
    rnn_size = self._rnn_output_size()
    sample_shape = self._helper.sample_ids_shape
    return BasicDecoderOutput(rnn_output=rnn_size, sample_id=sample_shape)

  @property
  def output_dtype(self):
    """A `BasicDecoderOutput` of dtypes matching the `output_size` structure.

    The RNN output dtype is assumed to be the dtype of the first component
    of the (flattened) initial state; it is replicated over the RNN output
    size structure. The sample id dtype comes from the helper.
    """
    first_state_dtype = nest.flatten(self._initial_state)[0].dtype

    def _to_dtype(unused_entry):
      return first_state_dtype

    rnn_dtypes = nest.map_structure(_to_dtype, self._rnn_output_size())
    sample_dtype = self._helper.sample_ids_dtype
    return BasicDecoderOutput(rnn_output=rnn_dtypes, sample_id=sample_dtype)

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, first_inputs, initial_state)`.
    """
    # The helper's own initialization tuple is extended with the initial
    # state supplied at construction time.
    helper_start = self._helper.initialize()
    return helper_start + (self._initial_state,)

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Runs the cell on `inputs` and `state`, applies the output layer when one
    was provided, then lets the helper sample from the result and produce
    the inputs for the following step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    scope_values = (time, inputs, state)
    with ops.name_scope(name, "BasicDecoderStep", scope_values):
      cell_outputs, cell_state = self._cell(inputs, state)

      output_layer = self._output_layer
      if output_layer is not None:
        cell_outputs = output_layer(cell_outputs)

      helper = self._helper
      sample_ids = helper.sample(
          time=time, outputs=cell_outputs, state=cell_state)
      helper_result = helper.next_inputs(
          time=time,
          outputs=cell_outputs,
          state=cell_state,
          sample_ids=sample_ids)
      finished, next_inputs, next_state = helper_result

    outputs = BasicDecoderOutput(rnn_output=cell_outputs, sample_id=sample_ids)
    return (outputs, next_state, next_inputs, finished)