1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """A decoder that performs beam search.""" 16 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 21 import collections 22 import numpy as np 23 24 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops 25 from tensorflow.contrib.seq2seq.python.ops import decoder 26 from tensorflow.python.framework import dtypes 27 from tensorflow.python.framework import ops 28 from tensorflow.python.framework import tensor_shape 29 from tensorflow.python.framework import tensor_util 30 from tensorflow.python.layers import base as layers_base 31 from tensorflow.python.ops import array_ops 32 from tensorflow.python.ops import control_flow_ops 33 from tensorflow.python.ops import embedding_ops 34 from tensorflow.python.ops import math_ops 35 from tensorflow.python.ops import nn_ops 36 from tensorflow.python.ops import rnn_cell_impl 37 from tensorflow.python.ops import tensor_array_ops 38 from tensorflow.python.util import nest 39 40 __all__ = [ 41 "BeamSearchDecoderOutput", 42 "BeamSearchDecoderState", 43 "BeamSearchDecoder", 44 "FinalBeamSearchDecoderOutput", 45 "tile_batch", 46 ] 47 48 49 class BeamSearchDecoderState( 50 collections.namedtuple("BeamSearchDecoderState", 51 ("cell_state", "log_probs", "finished", "lengths"))): 52 pass 53 54 55 class BeamSearchDecoderOutput( 56 collections.namedtuple("BeamSearchDecoderOutput", 57 ("scores", "predicted_ids", "parent_ids"))): 58 pass 59 60 61 class FinalBeamSearchDecoderOutput( 62 collections.namedtuple("FinalBeamDecoderOutput", 63 ["predicted_ids", "beam_search_decoder_output"])): 64 """Final outputs returned by the beam search after all decoding is finished. 65 66 Args: 67 predicted_ids: The final prediction. A tensor of shape 68 `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if 69 `output_time_major` is True). Beams are ordered from best to worst. 70 beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that 71 describes the state of the beam search. 72 """ 73 pass 74 75 76 def _tile_batch(t, multiplier): 77 """Core single-tensor implementation of tile_batch.""" 78 t = ops.convert_to_tensor(t, name="t") 79 shape_t = array_ops.shape(t) 80 if t.shape.ndims is None or t.shape.ndims < 1: 81 raise ValueError("t must have statically known rank") 82 tiling = [1] * (t.shape.ndims + 1) 83 tiling[1] = multiplier 84 tiled_static_batch_size = ( 85 t.shape[0].value * multiplier if t.shape[0].value is not None else None) 86 tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling) 87 tiled = array_ops.reshape(tiled, 88 array_ops.concat( 89 ([shape_t[0] * multiplier], shape_t[1:]), 0)) 90 tiled.set_shape( 91 tensor_shape.TensorShape([tiled_static_batch_size]).concatenate( 92 t.shape[1:])) 93 return tiled 94 95 96 def tile_batch(t, multiplier, name=None): 97 """Tile the batch dimension of a (possibly nested structure of) tensor(s) t. 98 99 For each tensor t in a (possibly nested structure) of tensors, 100 this function takes a tensor t shaped `[batch_size, s0, s1, ...]` composed of 101 minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have a shape 102 `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries 103 `t[0], t[0], ..., t[1], t[1], ...` where each minibatch entry is repeated 104 `multiplier` times. 105 106 Args: 107 t: `Tensor` shaped `[batch_size, ...]`. 108 multiplier: Python int. 109 name: Name scope for any created operations. 110 111 Returns: 112 A (possibly nested structure of) `Tensor` shaped 113 `[batch_size * multiplier, ...]`. 114 115 Raises: 116 ValueError: if tensor(s) `t` do not have a statically known rank or 117 the rank is < 1. 118 """ 119 flat_t = nest.flatten(t) 120 with ops.name_scope(name, "tile_batch", flat_t + [multiplier]): 121 return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t) 122 123 124 def _check_maybe(t): 125 if isinstance(t, tensor_array_ops.TensorArray): 126 raise TypeError( 127 "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name) 128 if t.shape.ndims is None: 129 raise ValueError( 130 "Expected tensor (%s) to have known rank, but ndims == None." % t) 131 132 133 class BeamSearchDecoder(decoder.Decoder): 134 """BeamSearch sampling decoder. 135 136 **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in 137 `AttentionWrapper`, then you must ensure that: 138 139 - The encoder output has been tiled to `beam_width` via 140 @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`). 141 - The `batch_size` argument passed to the `zero_state` method of this 142 wrapper is equal to `true_batch_size * beam_width`. 143 - The initial state created with `zero_state` above contains a 144 `cell_state` value containing properly tiled final state from the 145 encoder. 146 147 An example: 148 149 ``` 150 tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch( 151 encoder_outputs, multiplier=beam_width) 152 tiled_encoder_final_state = tf.conrib.seq2seq.tile_batch( 153 encoder_final_state, multiplier=beam_width) 154 tiled_sequence_length = tf.contrib.seq2seq.tile_batch( 155 sequence_length, multiplier=beam_width) 156 attention_mechanism = MyFavoriteAttentionMechanism( 157 num_units=attention_depth, 158 memory=tiled_inputs, 159 memory_sequence_length=tiled_sequence_length) 160 attention_cell = AttentionWrapper(cell, attention_mechanism, ...) 161 decoder_initial_state = attention_cell.zero_state( 162 dtype, batch_size=true_batch_size * beam_width) 163 decoder_initial_state = decoder_initial_state.clone( 164 cell_state=tiled_encoder_final_state) 165 ``` 166 """ 167 168 def __init__(self, 169 cell, 170 embedding, 171 start_tokens, 172 end_token, 173 initial_state, 174 beam_width, 175 output_layer=None, 176 length_penalty_weight=0.0): 177 """Initialize the BeamSearchDecoder. 178 179 Args: 180 cell: An `RNNCell` instance. 181 embedding: A callable that takes a vector tensor of `ids` (argmax ids), 182 or the `params` argument for `embedding_lookup`. 183 start_tokens: `int32` vector shaped `[batch_size]`, the start tokens. 184 end_token: `int32` scalar, the token that marks end of decoding. 185 initial_state: A (possibly nested tuple of...) tensors and TensorArrays. 186 beam_width: Python integer, the number of beams. 187 output_layer: (Optional) An instance of `tf.layers.Layer`, i.e., 188 `tf.layers.Dense`. Optional layer to apply to the RNN output prior 189 to storing the result or sampling. 190 length_penalty_weight: Float weight to penalize length. Disabled with 0.0. 191 192 Raises: 193 TypeError: if `cell` is not an instance of `RNNCell`, 194 or `output_layer` is not an instance of `tf.layers.Layer`. 195 ValueError: If `start_tokens` is not a vector or 196 `end_token` is not a scalar. 197 """ 198 if not rnn_cell_impl._like_rnncell(cell): # pylint: disable=protected-access 199 raise TypeError("cell must be an RNNCell, received: %s" % type(cell)) 200 if (output_layer is not None and 201 not isinstance(output_layer, layers_base.Layer)): 202 raise TypeError( 203 "output_layer must be a Layer, received: %s" % type(output_layer)) 204 self._cell = cell 205 self._output_layer = output_layer 206 207 if callable(embedding): 208 self._embedding_fn = embedding 209 else: 210 self._embedding_fn = ( 211 lambda ids: embedding_ops.embedding_lookup(embedding, ids)) 212 213 self._start_tokens = ops.convert_to_tensor( 214 start_tokens, dtype=dtypes.int32, name="start_tokens") 215 if self._start_tokens.get_shape().ndims != 1: 216 raise ValueError("start_tokens must be a vector") 217 self._end_token = ops.convert_to_tensor( 218 end_token, dtype=dtypes.int32, name="end_token") 219 if self._end_token.get_shape().ndims != 0: 220 raise ValueError("end_token must be a scalar") 221 222 self._batch_size = array_ops.size(start_tokens) 223 self._beam_width = beam_width 224 self._length_penalty_weight = length_penalty_weight 225 self._initial_cell_state = nest.map_structure( 226 self._maybe_split_batch_beams, initial_state, self._cell.state_size) 227 self._start_tokens = array_ops.tile( 228 array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width]) 229 self._start_inputs = self._embedding_fn(self._start_tokens) 230 231 self._finished = array_ops.one_hot( 232 array_ops.zeros([self._batch_size], dtype=dtypes.int32), 233 depth=self._beam_width, 234 on_value=False, 235 off_value=True, 236 dtype=dtypes.bool) 237 238 @property 239 def batch_size(self): 240 return self._batch_size 241 242 def _rnn_output_size(self): 243 size = self._cell.output_size 244 if self._output_layer is None: 245 return size 246 else: 247 # To use layer's compute_output_shape, we need to convert the 248 # RNNCell's output_size entries into shapes with an unknown 249 # batch size. We then pass this through the layer's 250 # compute_output_shape and read off all but the first (batch) 251 # dimensions to get the output size of the rnn with the layer 252 # applied to the top. 253 output_shape_with_unknown_batch = nest.map_structure( 254 lambda s: tensor_shape.TensorShape([None]).concatenate(s), size) 255 layer_output_shape = self._output_layer.compute_output_shape( 256 output_shape_with_unknown_batch) 257 return nest.map_structure(lambda s: s[1:], layer_output_shape) 258 259 @property 260 def tracks_own_finished(self): 261 """The BeamSearchDecoder shuffles its beams and their finished state. 262 263 For this reason, it conflicts with the `dynamic_decode` function's 264 tracking of finished states. Setting this property to true avoids 265 early stopping of decoding due to mismanagement of the finished state 266 in `dynamic_decode`. 267 268 Returns: 269 `True`. 270 """ 271 return True 272 273 @property 274 def output_size(self): 275 # Return the cell output and the id 276 return BeamSearchDecoderOutput( 277 scores=tensor_shape.TensorShape([self._beam_width]), 278 predicted_ids=tensor_shape.TensorShape([self._beam_width]), 279 parent_ids=tensor_shape.TensorShape([self._beam_width])) 280 281 @property 282 def output_dtype(self): 283 # Assume the dtype of the cell is the output_size structure 284 # containing the input_state's first component's dtype. 285 # Return that structure and int32 (the id) 286 dtype = nest.flatten(self._initial_cell_state)[0].dtype 287 return BeamSearchDecoderOutput( 288 scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()), 289 predicted_ids=dtypes.int32, 290 parent_ids=dtypes.int32) 291 292 def initialize(self, name=None): 293 """Initialize the decoder. 294 295 Args: 296 name: Name scope for any created operations. 297 298 Returns: 299 `(finished, start_inputs, initial_state)`. 300 """ 301 finished, start_inputs = self._finished, self._start_inputs 302 303 log_probs = array_ops.one_hot( # shape(batch_sz, beam_sz) 304 array_ops.zeros([self._batch_size], dtype=dtypes.int32), 305 depth=self._beam_width, 306 on_value=0.0, 307 off_value=-np.Inf, 308 dtype=nest.flatten(self._initial_cell_state)[0].dtype) 309 310 initial_state = BeamSearchDecoderState( 311 cell_state=self._initial_cell_state, 312 log_probs=log_probs, 313 finished=finished, 314 lengths=array_ops.zeros( 315 [self._batch_size, self._beam_width], dtype=dtypes.int64)) 316 317 return (finished, start_inputs, initial_state) 318 319 def finalize(self, outputs, final_state, sequence_lengths): 320 """Finalize and return the predicted_ids. 321 322 Args: 323 outputs: An instance of BeamSearchDecoderOutput. 324 final_state: An instance of BeamSearchDecoderState. Passed through to the 325 output. 326 sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`. 327 The sequence lengths determined for each beam during decode. 328 **NOTE** These are ignored; the updated sequence lengths are stored in 329 `final_state.lengths`. 330 331 Returns: 332 outputs: An instance of `FinalBeamSearchDecoderOutput` where the 333 predicted_ids are the result of calling _gather_tree. 334 final_state: The same input instance of `BeamSearchDecoderState`. 335 """ 336 del sequence_lengths 337 # Get max_sequence_length across all beams for each batch. 338 max_sequence_lengths = math_ops.to_int32( 339 math_ops.reduce_max(final_state.lengths, axis=1)) 340 predicted_ids = beam_search_ops.gather_tree( 341 outputs.predicted_ids, 342 outputs.parent_ids, 343 max_sequence_lengths=max_sequence_lengths, 344 end_token=self._end_token) 345 outputs = FinalBeamSearchDecoderOutput( 346 beam_search_decoder_output=outputs, predicted_ids=predicted_ids) 347 return outputs, final_state 348 349 def _merge_batch_beams(self, t, s=None): 350 """Merges the tensor from a batch of beams into a batch by beams. 351 352 More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We 353 reshape this into [batch_size*beam_width, s] 354 355 Args: 356 t: Tensor of dimension [batch_size, beam_width, s] 357 s: (Possibly known) depth shape. 358 359 Returns: 360 A reshaped version of t with dimension [batch_size * beam_width, s]. 361 """ 362 if isinstance(s, ops.Tensor): 363 s = tensor_shape.as_shape(tensor_util.constant_value(s)) 364 else: 365 s = tensor_shape.TensorShape(s) 366 t_shape = array_ops.shape(t) 367 static_batch_size = tensor_util.constant_value(self._batch_size) 368 batch_size_beam_width = ( 369 None 370 if static_batch_size is None else static_batch_size * self._beam_width) 371 reshaped_t = array_ops.reshape( 372 t, 373 array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]), 374 0)) 375 reshaped_t.set_shape( 376 (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s))) 377 return reshaped_t 378 379 def _split_batch_beams(self, t, s=None): 380 """Splits the tensor from a batch by beams into a batch of beams. 381 382 More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We 383 reshape this into [batch_size, beam_width, s] 384 385 Args: 386 t: Tensor of dimension [batch_size*beam_width, s]. 387 s: (Possibly known) depth shape. 388 389 Returns: 390 A reshaped version of t with dimension [batch_size, beam_width, s]. 391 392 Raises: 393 ValueError: If, after reshaping, the new tensor is not shaped 394 `[batch_size, beam_width, s]` (assuming batch_size and beam_width 395 are known statically). 396 """ 397 if isinstance(s, ops.Tensor): 398 s = tensor_shape.TensorShape(tensor_util.constant_value(s)) 399 else: 400 s = tensor_shape.TensorShape(s) 401 t_shape = array_ops.shape(t) 402 reshaped_t = array_ops.reshape( 403 t, 404 array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]), 405 0)) 406 static_batch_size = tensor_util.constant_value(self._batch_size) 407 expected_reshaped_shape = tensor_shape.TensorShape( 408 [static_batch_size, self._beam_width]).concatenate(s) 409 if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape): 410 raise ValueError("Unexpected behavior when reshaping between beam width " 411 "and batch size. The reshaped tensor has shape: %s. " 412 "We expected it to have shape " 413 "(batch_size, beam_width, depth) == %s. Perhaps you " 414 "forgot to create a zero_state with " 415 "batch_size=encoder_batch_size * beam_width?" % 416 (reshaped_t.shape, expected_reshaped_shape)) 417 reshaped_t.set_shape(expected_reshaped_shape) 418 return reshaped_t 419 420 def _maybe_split_batch_beams(self, t, s): 421 """Maybe splits the tensor from a batch by beams into a batch of beams. 422 423 We do this so that we can use nest and not run into problems with shapes. 424 425 Args: 426 t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`. 427 s: `Tensor`, Python int, or `TensorShape`. 428 429 Returns: 430 If `t` is a matrix or higher order tensor, then the return value is 431 `t` reshaped to `[batch_size, beam_width] + s`. Otherwise `t` is 432 returned unchanged. 433 434 Raises: 435 TypeError: If `t` is an instance of `TensorArray`. 436 ValueError: If the rank of `t` is not statically known. 437 """ 438 _check_maybe(t) 439 if t.shape.ndims >= 1: 440 return self._split_batch_beams(t, s) 441 else: 442 return t 443 444 def _maybe_merge_batch_beams(self, t, s): 445 """Splits the tensor from a batch by beams into a batch of beams. 446 447 More exactly, `t` is a tensor of dimension `[batch_size * beam_width] + s`, 448 then we reshape it to `[batch_size, beam_width] + s`. 449 450 Args: 451 t: `Tensor` of dimension `[batch_size * beam_width] + s`. 452 s: `Tensor`, Python int, or `TensorShape`. 453 454 Returns: 455 A reshaped version of t with shape `[batch_size, beam_width] + s`. 456 457 Raises: 458 TypeError: If `t` is an instance of `TensorArray`. 459 ValueError: If the rank of `t` is not statically known. 460 """ 461 _check_maybe(t) 462 if t.shape.ndims >= 2: 463 return self._merge_batch_beams(t, s) 464 else: 465 return t 466 467 def step(self, time, inputs, state, name=None): 468 """Perform a decoding step. 469 470 Args: 471 time: scalar `int32` tensor. 472 inputs: A (structure of) input tensors. 473 state: A (structure of) state tensors and TensorArrays. 474 name: Name scope for any created operations. 475 476 Returns: 477 `(outputs, next_state, next_inputs, finished)`. 478 """ 479 batch_size = self._batch_size 480 beam_width = self._beam_width 481 end_token = self._end_token 482 length_penalty_weight = self._length_penalty_weight 483 484 with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)): 485 cell_state = state.cell_state 486 inputs = nest.map_structure( 487 lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs) 488 cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state, 489 self._cell.state_size) 490 cell_outputs, next_cell_state = self._cell(inputs, cell_state) 491 cell_outputs = nest.map_structure( 492 lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs) 493 next_cell_state = nest.map_structure( 494 self._maybe_split_batch_beams, next_cell_state, self._cell.state_size) 495 496 if self._output_layer is not None: 497 cell_outputs = self._output_layer(cell_outputs) 498 499 beam_search_output, beam_search_state = _beam_search_step( 500 time=time, 501 logits=cell_outputs, 502 next_cell_state=next_cell_state, 503 beam_state=state, 504 batch_size=batch_size, 505 beam_width=beam_width, 506 end_token=end_token, 507 length_penalty_weight=length_penalty_weight) 508 509 finished = beam_search_state.finished 510 sample_ids = beam_search_output.predicted_ids 511 next_inputs = control_flow_ops.cond( 512 math_ops.reduce_all(finished), lambda: self._start_inputs, 513 lambda: self._embedding_fn(sample_ids)) 514 515 return (beam_search_output, beam_search_state, next_inputs, finished) 516 517 518 def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size, 519 beam_width, end_token, length_penalty_weight): 520 """Performs a single step of Beam Search Decoding. 521 522 Args: 523 time: Beam search time step, should start at 0. At time 0 we assume 524 that all beams are equal and consider only the first beam for 525 continuations. 526 logits: Logits at the current time step. A tensor of shape 527 `[batch_size, beam_width, vocab_size]` 528 next_cell_state: The next state from the cell, e.g. an instance of 529 AttentionWrapperState if the cell is attentional. 530 beam_state: Current state of the beam search. 531 An instance of `BeamSearchDecoderState`. 532 batch_size: The batch size for this input. 533 beam_width: Python int. The size of the beams. 534 end_token: The int32 end token. 535 length_penalty_weight: Float weight to penalize length. Disabled with 0.0. 536 537 Returns: 538 A new beam state. 539 """ 540 static_batch_size = tensor_util.constant_value(batch_size) 541 542 # Calculate the current lengths of the predictions 543 prediction_lengths = beam_state.lengths 544 previously_finished = beam_state.finished 545 546 # Calculate the total log probs for the new hypotheses 547 # Final Shape: [batch_size, beam_width, vocab_size] 548 step_log_probs = nn_ops.log_softmax(logits) 549 step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished) 550 total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs 551 552 # Calculate the continuation lengths by adding to all continuing beams. 553 vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1] 554 lengths_to_add = array_ops.one_hot( 555 indices=array_ops.fill([batch_size, beam_width], end_token), 556 depth=vocab_size, 557 on_value=np.int64(0), 558 off_value=np.int64(1), 559 dtype=dtypes.int64) 560 add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished)) 561 lengths_to_add *= array_ops.expand_dims(add_mask, 2) 562 new_prediction_lengths = ( 563 lengths_to_add + array_ops.expand_dims(prediction_lengths, 2)) 564 565 # Calculate the scores for each beam 566 scores = _get_scores( 567 log_probs=total_probs, 568 sequence_lengths=new_prediction_lengths, 569 length_penalty_weight=length_penalty_weight) 570 571 time = ops.convert_to_tensor(time, name="time") 572 # During the first time step we only consider the initial beam 573 scores_shape = array_ops.shape(scores) 574 scores_flat = array_ops.reshape(scores, [batch_size, -1]) 575 576 # Pick the next beams according to the specified successors function 577 next_beam_size = ops.convert_to_tensor( 578 beam_width, dtype=dtypes.int32, name="beam_width") 579 next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size) 580 581 next_beam_scores.set_shape([static_batch_size, beam_width]) 582 word_indices.set_shape([static_batch_size, beam_width]) 583 584 # Pick out the probs, beam_ids, and states according to the chosen predictions 585 next_beam_probs = _tensor_gather_helper( 586 gather_indices=word_indices, 587 gather_from=total_probs, 588 batch_size=batch_size, 589 range_size=beam_width * vocab_size, 590 gather_shape=[-1], 591 name="next_beam_probs") 592 # Note: just doing the following 593 # math_ops.to_int32(word_indices % vocab_size, 594 # name="next_beam_word_ids") 595 # would be a lot cleaner but for reasons unclear, that hides the results of 596 # the op which prevents capturing it with tfdbg debug ops. 597 raw_next_word_ids = math_ops.mod( 598 word_indices, vocab_size, name="next_beam_word_ids") 599 next_word_ids = math_ops.to_int32(raw_next_word_ids) 600 next_beam_ids = math_ops.to_int32( 601 word_indices / vocab_size, name="next_beam_parent_ids") 602 603 # Append new ids to current predictions 604 previously_finished = _tensor_gather_helper( 605 gather_indices=next_beam_ids, 606 gather_from=previously_finished, 607 batch_size=batch_size, 608 range_size=beam_width, 609 gather_shape=[-1]) 610 next_finished = math_ops.logical_or( 611 previously_finished, 612 math_ops.equal(next_word_ids, end_token), 613 name="next_beam_finished") 614 615 # Calculate the length of the next predictions. 616 # 1. Finished beams remain unchanged. 617 # 2. Beams that are now finished (EOS predicted) have their length 618 # increased by 1. 619 # 3. Beams that are not yet finished have their length increased by 1. 620 lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished)) 621 next_prediction_len = _tensor_gather_helper( 622 gather_indices=next_beam_ids, 623 gather_from=beam_state.lengths, 624 batch_size=batch_size, 625 range_size=beam_width, 626 gather_shape=[-1]) 627 next_prediction_len += lengths_to_add 628 629 # Pick out the cell_states according to the next_beam_ids. We use a 630 # different gather_shape here because the cell_state tensors, i.e. 631 # the tensors that would be gathered from, all have dimension 632 # greater than two and we need to preserve those dimensions. 633 # pylint: disable=g-long-lambda 634 next_cell_state = nest.map_structure( 635 lambda gather_from: _maybe_tensor_gather_helper( 636 gather_indices=next_beam_ids, 637 gather_from=gather_from, 638 batch_size=batch_size, 639 range_size=beam_width, 640 gather_shape=[batch_size * beam_width, -1]), 641 next_cell_state) 642 # pylint: enable=g-long-lambda 643 644 next_state = BeamSearchDecoderState( 645 cell_state=next_cell_state, 646 log_probs=next_beam_probs, 647 lengths=next_prediction_len, 648 finished=next_finished) 649 650 output = BeamSearchDecoderOutput( 651 scores=next_beam_scores, 652 predicted_ids=next_word_ids, 653 parent_ids=next_beam_ids) 654 655 return output, next_state 656 657 658 def _get_scores(log_probs, sequence_lengths, length_penalty_weight): 659 """Calculates scores for beam search hypotheses. 660 661 Args: 662 log_probs: The log probabilities with shape 663 `[batch_size, beam_width, vocab_size]`. 664 sequence_lengths: The array of sequence lengths. 665 length_penalty_weight: Float weight to penalize length. Disabled with 0.0. 666 667 Returns: 668 The scores normalized by the length_penalty. 669 """ 670 length_penality_ = _length_penalty( 671 sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight) 672 return log_probs / length_penality_ 673 674 675 def _length_penalty(sequence_lengths, penalty_factor): 676 """Calculates the length penalty. See https://arxiv.org/abs/1609.08144. 677 678 Returns the length penalty tensor: 679 ``` 680 [(5+sequence_lengths)/6]**penalty_factor 681 ``` 682 where all operations are performed element-wise. 683 684 Args: 685 sequence_lengths: `Tensor`, the sequence lengths of each hypotheses. 686 penalty_factor: A scalar that weights the length penalty. 687 688 Returns: 689 If the penalty is `0`, returns the scalar `1.0`. Otherwise returns 690 the length penalty factor, a tensor with the same shape as 691 `sequence_lengths`. 692 """ 693 penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor") 694 penalty_factor.set_shape(()) # penalty should be a scalar. 695 static_penalty = tensor_util.constant_value(penalty_factor) 696 if static_penalty is not None and static_penalty == 0: 697 return 1.0 698 return math_ops.div((5. + math_ops.to_float(sequence_lengths)) 699 **penalty_factor, (5. + 1.)**penalty_factor) 700 701 702 def _mask_probs(probs, eos_token, finished): 703 """Masks log probabilities. 704 705 The result is that finished beams allocate all probability mass to eos and 706 unfinished beams remain unchanged. 707 708 Args: 709 probs: Log probabiltiies of shape `[batch_size, beam_width, vocab_size]` 710 eos_token: An int32 id corresponding to the EOS token to allocate 711 probability to. 712 finished: A boolean tensor of shape `[batch_size, beam_width]` that 713 specifies which elements in the beam are finished already. 714 715 Returns: 716 A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished 717 beams stay unchanged and finished beams are replaced with a tensor with all 718 probability on the EOS token. 719 """ 720 vocab_size = array_ops.shape(probs)[2] 721 # All finished examples are replaced with a vector that has all 722 # probability on EOS 723 finished_row = array_ops.one_hot( 724 eos_token, 725 vocab_size, 726 dtype=probs.dtype, 727 on_value=0., 728 off_value=probs.dtype.min) 729 finished_probs = array_ops.tile( 730 array_ops.reshape(finished_row, [1, 1, -1]), 731 array_ops.concat([array_ops.shape(finished), [1]], 0)) 732 finished_mask = array_ops.tile( 733 array_ops.expand_dims(finished, 2), [1, 1, vocab_size]) 734 735 return array_ops.where(finished_mask, finished_probs, probs) 736 737 738 def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size, 739 range_size, gather_shape): 740 """Maybe applies _tensor_gather_helper. 741 742 This applies _tensor_gather_helper when the gather_from dims is at least as 743 big as the length of gather_shape. This is used in conjunction with nest so 744 that we don't apply _tensor_gather_helper to inapplicable values like scalars. 745 746 Args: 747 gather_indices: The tensor indices that we use to gather. 748 gather_from: The tensor that we are gathering from. 749 batch_size: The batch size. 750 range_size: The number of values in each range. Likely equal to beam_width. 751 gather_shape: What we should reshape gather_from to in order to preserve the 752 correct values. An example is when gather_from is the attention from an 753 AttentionWrapperState with shape [batch_size, beam_width, attention_size]. 754 There, we want to preserve the attention_size elements, so gather_shape is 755 [batch_size * beam_width, -1]. Then, upon reshape, we still have the 756 attention_size as desired. 757 758 Returns: 759 output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] 760 or the original tensor if its dimensions are too small. 761 """ 762 _check_maybe(gather_from) 763 if gather_from.shape.ndims >= len(gather_shape): 764 return _tensor_gather_helper( 765 gather_indices=gather_indices, 766 gather_from=gather_from, 767 batch_size=batch_size, 768 range_size=range_size, 769 gather_shape=gather_shape) 770 else: 771 return gather_from 772 773 774 def _tensor_gather_helper(gather_indices, 775 gather_from, 776 batch_size, 777 range_size, 778 gather_shape, 779 name=None): 780 """Helper for gathering the right indices from the tensor. 781 782 This works by reshaping gather_from to gather_shape (e.g. [-1]) and then 783 gathering from that according to the gather_indices, which are offset by 784 the right amounts in order to preserve the batch order. 785 786 Args: 787 gather_indices: The tensor indices that we use to gather. 788 gather_from: The tensor that we are gathering from. 789 batch_size: The input batch size. 790 range_size: The number of values in each range. Likely equal to beam_width. 791 gather_shape: What we should reshape gather_from to in order to preserve the 792 correct values. An example is when gather_from is the attention from an 793 AttentionWrapperState with shape [batch_size, beam_width, attention_size]. 794 There, we want to preserve the attention_size elements, so gather_shape is 795 [batch_size * beam_width, -1]. Then, upon reshape, we still have the 796 attention_size as desired. 797 name: The tensor name for set of operations. By default this is 798 'tensor_gather_helper'. The final output is named 'output'. 799 800 Returns: 801 output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)] 802 """ 803 with ops.name_scope(name, "tensor_gather_helper"): 804 range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1) 805 gather_indices = array_ops.reshape(gather_indices + range_, [-1]) 806 output = array_ops.gather( 807 array_ops.reshape(gather_from, gather_shape), gather_indices) 808 final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)] 809 static_batch_size = tensor_util.constant_value(batch_size) 810 final_static_shape = ( 811 tensor_shape.TensorShape([static_batch_size]).concatenate( 812 gather_from.shape[1:1 + len(gather_shape)])) 813 output = array_ops.reshape(output, final_shape, name="output") 814 output.set_shape(final_static_shape) 815 return output 816