# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Seq2seq layer operations for use in neural networks.

Defines the abstract `Decoder` interface and `dynamic_decode`, which calls a
`Decoder`'s `step` inside a `while_loop` until every batch entry is finished
(or `maximum_iterations` is reached).
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc
import six

from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import rnn
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops import variable_scope
from tensorflow.python.util import nest


__all__ = ["Decoder", "dynamic_decode"]


# Reused from the rnn module to convert time-major outputs to batch-major.
_transpose_batch_time = rnn._transpose_batch_time  # pylint: disable=protected-access


@six.add_metaclass(abc.ABCMeta)
class Decoder(object):
  """An RNN Decoder abstract interface object.

  Subclasses must implement the abstract `initialize` and `step` methods and
  must override the `batch_size`, `output_size` and `output_dtype`
  properties, which raise `NotImplementedError` here. `finalize` is optional:
  `dynamic_decode` skips it when it raises `NotImplementedError`.

  Concepts used by this interface:
  - `inputs`: (structure of) tensors and TensorArrays that is passed as input to
    the RNNCell composing the decoder, at each time step.
  - `state`: (structure of) tensors and TensorArrays that is passed to the
    RNNCell instance as the state.
  - `finished`: boolean tensor telling whether each sequence in the batch is
    finished.
  - `outputs`: Result of the decoding, at each time step. Its nested structure
    must match `output_size` / `output_dtype` (e.g. an instance of
    `BasicDecoderOutput`).
  """

  @property
  def batch_size(self):
    """The batch size of input values.

    `dynamic_decode` uses it as the leading dimension of the zero outputs and
    of the output TensorArray element shapes. Must be overridden.
    """
    raise NotImplementedError

  @property
  def output_size(self):
    """A (possibly nested tuple of...) integer[s] or `TensorShape` object[s].

    Per-example shape of each output, excluding the batch dimension (which
    `dynamic_decode` prepends). Must be overridden.
    """
    raise NotImplementedError

  @property
  def output_dtype(self):
    """A (possibly nested tuple of...) dtype[s].

    Must have the same nested structure as `output_size`. Must be overridden.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def initialize(self, name=None):
    """Called before any decoding iterations.

    This method must compute initial input values and initial state.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, initial_inputs, initial_state)`: initial values of
      'finished' flags, inputs and state.
    """
    raise NotImplementedError

  @abc.abstractmethod
  def step(self, time, inputs, state, name=None):
    """Called per step of decoding (but only once for dynamic decoding).

    `dynamic_decode` calls this method from inside its `while_loop` body
    function, so the Python call builds the per-step ops rather than running
    once per time step.

    Args:
      time: Scalar `int32` tensor. Current step number.
      inputs: RNNCell input (possibly nested tuple of) tensor[s] for this time
        step.
      state: RNNCell state (possibly nested tuple of) tensor[s] from previous
        time step.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`: `outputs` is an object
      containing the decoder output, `next_state` is a (structure of) state
      tensors and TensorArrays, `next_inputs` is the tensor that should be used
      as input for the next step, `finished` is a boolean tensor telling whether
      the sequence is complete, for each sequence in the batch.
      `dynamic_decode` asserts that `outputs`, `next_state` and `next_inputs`
      have the same nested structure as `output_size`, `state` and `inputs`
      respectively.
    """
    raise NotImplementedError

  def finalize(self, outputs, final_state, sequence_lengths):
    """Optional hook applied to the decoding results after the loop ends.

    `dynamic_decode` calls this with the stacked, time-major outputs. The call
    happens before any batch-major transpose. Overriding is optional: this
    default raises `NotImplementedError`, which `dynamic_decode` catches and
    ignores, leaving outputs and state unchanged.

    Args:
      outputs: (structure of) stacked output tensors, time major.
      final_state: (structure of) decoder state after the last step.
      sequence_lengths: `int32` tensor with, for each batch entry, the step
        count at which it became finished (0 if finished after `initialize`).

    Returns:
      `(final_outputs, final_state)`.

    Raises:
      NotImplementedError: in this base implementation.
    """
    raise NotImplementedError

  @property
  def tracks_own_finished(self):
    """Describes whether the Decoder keeps track of finished states.

    Most decoders will emit a true/false `finished` value independently
    at each time step. In this case, the `dynamic_decode` function keeps track
    of which batch entries are already finished, and performs a logical OR to
    add newly finished batch entries to the finished set.

    Some decoders, however, shuffle batches / beams between time steps and
    `dynamic_decode` will mix up the finished state across these entries because
    it does not track the reshuffle across time steps. In this case, it is
    up to the decoder to declare that it will keep track of its own finished
    state by setting this property to `True`.

    Returns:
      Python bool.
    """
    return False


def _create_zero_outputs(size, dtype, batch_size):
  """Create a zero outputs Tensor structure.

  Args:
    size: (possibly nested) structure of per-example output shapes. A leaf
      that is already a `Tensor` is used as the shape suffix as-is (presumably
      a 1-D integer shape vector). Any other leaf is converted via
      `TensorShape(s).as_list()` into an `int32` constant.
      NOTE(review): that conversion presumably requires fully defined shapes;
      confirm for callers.
    dtype: structure of dtypes with the same nesting as `size`.
    batch_size: batch size, prepended as the leading dimension.

  Returns:
    A structure matching `size` of zero tensors of shape `[batch_size] + s`
    with the corresponding dtype.
  """
  def _t(s):
    # Shape suffix as a tensor: pass Tensors through, convert static shapes.
    return (s if isinstance(s, ops.Tensor) else constant_op.constant(
        tensor_shape.TensorShape(s).as_list(),
        dtype=dtypes.int32,
        name="zero_suffix_shape"))

  def _create(s, d):
    return array_ops.zeros(
        array_ops.concat(
            ([batch_size], _t(s)), axis=0), dtype=d)

  return nest.map_structure(_create, size, dtype)


def dynamic_decode(decoder,
                   output_time_major=False,
                   impute_finished=False,
                   maximum_iterations=None,
                   parallel_iterations=32,
                   swap_memory=False,
                   scope=None):
  """Perform dynamic decoding with `decoder`.

  Calls initialize() once and step() repeatedly on the Decoder object, inside
  a `while_loop` that stops once every batch entry is finished. If the decoder
  implements finalize(), it is applied to the stacked time-major outputs
  before the optional batch-major transpose.

  Args:
    decoder: A `Decoder` instance.
    output_time_major: Python boolean.  Default: `False` (batch major).  If
      `True`, outputs are returned as time major tensors (this mode is faster).
      Otherwise, outputs are returned as batch major tensors (this adds extra
      time to the computation).
    impute_finished: Python boolean.  If `True`, then states for batch
      entries which are marked as finished get copied through and the
      corresponding outputs get zeroed out.  This causes some slowdown at
      each time step, but ensures that the final state and outputs have
      the correct values and that backprop ignores time steps that were
      marked as finished.
    maximum_iterations: `int32` scalar, maximum allowed number of decoding
       steps.  Default is `None` (decode until the decoder is fully done).
       When set, every entry is marked finished once `time + 1` reaches it.
    parallel_iterations: Argument passed to `tf.while_loop`.
    swap_memory: Argument passed to `tf.while_loop`.
    scope: Optional variable scope to use.

  Returns:
    `(final_outputs, final_state, final_sequence_lengths)`.
    `final_sequence_lengths` is an `int32` tensor shaped like the `finished`
    flags. For each entry it holds the step count (`time + 1`) at which the
    entry first became finished, or 0 if it was already finished after
    `initialize()`.

  Raises:
    TypeError: if `decoder` is not an instance of `Decoder`.
    ValueError: if `maximum_iterations` is provided but is not a scalar.
  """
  if not isinstance(decoder, Decoder):
    raise TypeError("Expected decoder to be type Decoder, but saw: %s" %
                    type(decoder))

  with variable_scope.variable_scope(scope, "decoder") as varscope:
    # Properly cache variable values inside the while_loop. This only
    # installs a caching device (the consuming op's own device) when the
    # caller has not already set one.
    if varscope.caching_device is None:
      varscope.set_caching_device(lambda op: op.device)

    if maximum_iterations is not None:
      # Normalize to an int32 scalar tensor so it can be compared with `time`.
      maximum_iterations = ops.convert_to_tensor(
          maximum_iterations, dtype=dtypes.int32, name="maximum_iterations")
      if maximum_iterations.get_shape().ndims != 0:
        raise ValueError("maximum_iterations must be a scalar")

    initial_finished, initial_inputs, initial_state = decoder.initialize()

    zero_outputs = _create_zero_outputs(decoder.output_size,
                                        decoder.output_dtype,
                                        decoder.batch_size)

    if maximum_iterations is not None:
      # With maximum_iterations <= 0, every entry is finished before step 0.
      initial_finished = math_ops.logical_or(
          initial_finished, 0 >= maximum_iterations)
    initial_sequence_lengths = array_ops.zeros_like(
        initial_finished, dtype=dtypes.int32)
    initial_time = constant_op.constant(0, dtype=dtypes.int32)

    def _shape(batch_size, from_shape):
      # Element shape for an output TensorArray. For a TensorShape leaf this
      # is [batch_size] + from_shape, where the batch dim stays unknown unless
      # its value is statically available. Any other leaf gets a fully
      # unknown shape.
      if not isinstance(from_shape, tensor_shape.TensorShape):
        return tensor_shape.TensorShape(None)
      else:
        batch_size = tensor_util.constant_value(
            ops.convert_to_tensor(
                batch_size, name="batch_size"))
        return tensor_shape.TensorShape([batch_size]).concatenate(from_shape)

    def _create_ta(s, d):
      # One growable TensorArray per output leaf, written once per step.
      return tensor_array_ops.TensorArray(
          dtype=d,
          size=0,
          dynamic_size=True,
          element_shape=_shape(decoder.batch_size, s))

    initial_outputs_ta = nest.map_structure(_create_ta, decoder.output_size,
                                            decoder.output_dtype)

    def condition(unused_time, unused_outputs_ta, unused_state, unused_inputs,
                  finished, unused_sequence_lengths):
      """Keeps looping while at least one batch entry is unfinished."""
      return math_ops.logical_not(math_ops.reduce_all(finished))

    def body(time, outputs_ta, state, inputs, finished, sequence_lengths):
      """Internal while_loop body.

      Args:
        time: scalar int32 tensor.
        outputs_ta: structure of TensorArray.
        state: (structure of) state tensors and TensorArrays.
        inputs: (structure of) input tensors.
        finished: bool tensor (keeping track of what's finished).
        sequence_lengths: int32 tensor (keeping track of time of finish).

      Returns:
        `(time + 1, outputs_ta, next_state, next_inputs, next_finished,
          next_sequence_lengths)`.
      """
      (next_outputs, decoder_state, next_inputs,
       decoder_finished) = decoder.step(time, inputs, state)
      # Decoders that reorder batch entries report their own finished flags
      # (see `Decoder.tracks_own_finished`). Otherwise the flags are OR-ed
      # with the previous ones, so a finished entry stays finished.
      if decoder.tracks_own_finished:
        next_finished = decoder_finished
      else:
        next_finished = math_ops.logical_or(decoder_finished, finished)
      if maximum_iterations is not None:
        next_finished = math_ops.logical_or(
            next_finished, time + 1 >= maximum_iterations)
      # Record `time + 1` only for entries that become finished on this step.
      # Entries that were already finished keep their recorded length.
      next_sequence_lengths = array_ops.where(
          math_ops.logical_and(math_ops.logical_not(finished), next_finished),
          array_ops.fill(array_ops.shape(sequence_lengths), time + 1),
          sequence_lengths)

      # The step must preserve the nested structure of state/outputs/inputs.
      nest.assert_same_structure(state, decoder_state)
      nest.assert_same_structure(outputs_ta, next_outputs)
      nest.assert_same_structure(inputs, next_inputs)

      # Zero out output values past finish. This uses `finished` from before
      # this step, so the output of the step on which an entry finishes is
      # still emitted.
      if impute_finished:
        emit = nest.map_structure(
            lambda out, zero: array_ops.where(finished, zero, out),
            next_outputs,
            zero_outputs)
      else:
        emit = next_outputs

      # Copy through states past finish: entries that were already finished
      # before this step keep their previous state.
      def _maybe_copy_state(new, cur):
        # TensorArrays and scalar states get passed through.
        if isinstance(cur, tensor_array_ops.TensorArray):
          pass_through = True
        else:
          # Give `new` the static shape of `cur` so its rank is known below.
          new.set_shape(cur.shape)
          pass_through = (new.shape.ndims == 0)
        return new if pass_through else array_ops.where(finished, cur, new)

      if impute_finished:
        next_state = nest.map_structure(
            _maybe_copy_state, decoder_state, state)
      else:
        next_state = decoder_state

      # Store this step's (possibly zeroed) outputs at index `time`.
      outputs_ta = nest.map_structure(lambda ta, out: ta.write(time, out),
                                      outputs_ta, emit)
      return (time + 1, outputs_ta, next_state, next_inputs, next_finished,
              next_sequence_lengths)

    res = control_flow_ops.while_loop(
        condition,
        body,
        loop_vars=[
            initial_time, initial_outputs_ta, initial_state, initial_inputs,
            initial_finished, initial_sequence_lengths,
        ],
        parallel_iterations=parallel_iterations,
        swap_memory=swap_memory)

    # `res` mirrors loop_vars:
    # (time, outputs_ta, state, inputs, finished, sequence_lengths).
    final_outputs_ta = res[1]
    final_state = res[2]
    final_sequence_lengths = res[5]

    # Stack each TensorArray (indexed by time) into a time-major tensor.
    final_outputs = nest.map_structure(lambda ta: ta.stack(), final_outputs_ta)

    # `finalize` is optional; the base class raises NotImplementedError.
    try:
      final_outputs, final_state = decoder.finalize(
          final_outputs, final_state, final_sequence_lengths)
    except NotImplementedError:
      pass

    # Convert time-major outputs to batch-major unless asked otherwise.
    if not output_time_major:
      final_outputs = nest.map_structure(_transpose_batch_time, final_outputs)

  return final_outputs, final_state, final_sequence_lengths