# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A decoder that performs beam search."""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import collections
import numpy as np

from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.framework import tensor_util
from tensorflow.python.layers import base as layers_base
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell_impl
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.util import nest

__all__ = [
    "BeamSearchDecoderOutput",
    "BeamSearchDecoderState",
    "BeamSearchDecoder",
    "FinalBeamSearchDecoderOutput",
    "tile_batch",
]


class BeamSearchDecoderState(
    collections.namedtuple("BeamSearchDecoderState",
                           ("cell_state", "log_probs", "finished", "lengths"))):
  pass


class BeamSearchDecoderOutput(
    collections.namedtuple("BeamSearchDecoderOutput",
                           ("scores", "predicted_ids", "parent_ids"))):
  pass


class FinalBeamSearchDecoderOutput(
    collections.namedtuple("FinalBeamDecoderOutput",
                           ["predicted_ids", "beam_search_decoder_output"])):
  """Final outputs returned by the beam search after all decoding is finished.

  Args:
    predicted_ids: The final prediction. A tensor of shape
      `[batch_size, T, beam_width]` (or `[T, batch_size, beam_width]` if
      `output_time_major` is True). Beams are ordered from best to worst.
    beam_search_decoder_output: An instance of `BeamSearchDecoderOutput` that
      describes the state of the beam search.
  """
  pass


def _tile_batch(t, multiplier):
  """Core single-tensor implementation of tile_batch."""
  t = ops.convert_to_tensor(t, name="t")
  shape_t = array_ops.shape(t)
  if t.shape.ndims is None or t.shape.ndims < 1:
    raise ValueError("t must have statically known rank")
  tiling = [1] * (t.shape.ndims + 1)
  tiling[1] = multiplier
  tiled_static_batch_size = (
      t.shape[0].value * multiplier if t.shape[0].value is not None else None)
  tiled = array_ops.tile(array_ops.expand_dims(t, 1), tiling)
  tiled = array_ops.reshape(tiled,
                            array_ops.concat(
                                ([shape_t[0] * multiplier], shape_t[1:]), 0))
  tiled.set_shape(
      tensor_shape.TensorShape([tiled_static_batch_size]).concatenate(
          t.shape[1:]))
  return tiled

def tile_batch(t, multiplier, name=None):
  """Tile the batch dimension of a (possibly nested structure of) tensor(s) t.

  For each tensor `t` in a (possibly nested) structure of tensors, this
  function takes a tensor shaped `[batch_size, s0, s1, ...]` composed of
  minibatch entries `t[0], ..., t[batch_size - 1]` and tiles it to have shape
  `[batch_size * multiplier, s0, s1, ...]` composed of minibatch entries
  `t[0], t[0], ..., t[1], t[1], ...`, where each minibatch entry is repeated
  `multiplier` times.

  Args:
    t: `Tensor` shaped `[batch_size, ...]`.
    multiplier: Python int.
    name: Name scope for any created operations.

  Returns:
    A (possibly nested structure of) `Tensor` shaped
    `[batch_size * multiplier, ...]`.

  Raises:
    ValueError: if tensor(s) `t` do not have a statically known rank or
      the rank is < 1.
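
  For example (a sketch of the resulting layout):

  ```
  t = [[1, 2], [3, 4]]                 # shape [2, 2]
  tiled = tile_batch(t, multiplier=3)  # shape [6, 2]
  # tiled => [[1, 2], [1, 2], [1, 2],
  #           [3, 4], [3, 4], [3, 4]]
  ```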
    118   """
    119   flat_t = nest.flatten(t)
    120   with ops.name_scope(name, "tile_batch", flat_t + [multiplier]):
    121     return nest.map_structure(lambda t_: _tile_batch(t_, multiplier), t)
    122 
    123 
def _check_maybe(t):
  if isinstance(t, tensor_array_ops.TensorArray):
    raise TypeError(
        "TensorArray state is not supported by BeamSearchDecoder: %s" % t.name)
  if t.shape.ndims is None:
    raise ValueError(
        "Expected tensor (%s) to have known rank, but ndims == None." % t)


class BeamSearchDecoder(decoder.Decoder):
  """BeamSearch sampling decoder.

    **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
    `AttentionWrapper`, then you must ensure that:

    - The encoder output has been tiled to `beam_width` via
      @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
    - The `batch_size` argument passed to the `zero_state` method of this
      wrapper is equal to `true_batch_size * beam_width`.
    - The initial state created with `zero_state` above contains a
      `cell_state` value containing the properly tiled final state from the
      encoder.

    An example:

    ```
    tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
        encoder_outputs, multiplier=beam_width)
    tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
        encoder_final_state, multiplier=beam_width)
    tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
        sequence_length, multiplier=beam_width)
    attention_mechanism = MyFavoriteAttentionMechanism(
        num_units=attention_depth,
        memory=tiled_encoder_outputs,
        memory_sequence_length=tiled_sequence_length)
    attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
    decoder_initial_state = attention_cell.zero_state(
        dtype, batch_size=true_batch_size * beam_width)
    decoder_initial_state = decoder_initial_state.clone(
        cell_state=tiled_encoder_final_state)
    ```
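
    The decoder is then constructed and run in the usual way; the following
    sketch assumes `embedding`, `start_tokens`, `end_token`, `output_layer`
    and `max_decode_length` are defined elsewhere:

    ```
    decoder = tf.contrib.seq2seq.BeamSearchDecoder(
        cell=attention_cell,
        embedding=embedding,
        start_tokens=start_tokens,
        end_token=end_token,
        initial_state=decoder_initial_state,
        beam_width=beam_width,
        output_layer=output_layer,
        length_penalty_weight=0.0)
    outputs, final_state, _ = tf.contrib.seq2seq.dynamic_decode(
        decoder, maximum_iterations=max_decode_length)
    predicted_ids = outputs.predicted_ids  # [batch_size, T, beam_width]
    ```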
    166   """
    167 
    168   def __init__(self,
    169                cell,
    170                embedding,
    171                start_tokens,
    172                end_token,
    173                initial_state,
    174                beam_width,
    175                output_layer=None,
    176                length_penalty_weight=0.0):
    177     """Initialize the BeamSearchDecoder.
    178 
    179     Args:
    180       cell: An `RNNCell` instance.
    181       embedding: A callable that takes a vector tensor of `ids` (argmax ids),
    182         or the `params` argument for `embedding_lookup`.
    183       start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
    184       end_token: `int32` scalar, the token that marks end of decoding.
    185       initial_state: A (possibly nested tuple of...) tensors and TensorArrays.
    186       beam_width:  Python integer, the number of beams.
    187       output_layer: (Optional) An instance of `tf.layers.Layer`, i.e.,
    188         `tf.layers.Dense`.  Optional layer to apply to the RNN output prior
    189         to storing the result or sampling.
    190       length_penalty_weight: Float weight to penalize length. Disabled with 0.0.
    191 
    192     Raises:
    193       TypeError: if `cell` is not an instance of `RNNCell`,
    194         or `output_layer` is not an instance of `tf.layers.Layer`.
    195       ValueError: If `start_tokens` is not a vector or
    196         `end_token` is not a scalar.
    197     """
    198     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
    199       raise TypeError("cell must be an RNNCell, received: %s" % type(cell))
    200     if (output_layer is not None and
    201         not isinstance(output_layer, layers_base.Layer)):
    202       raise TypeError(
    203           "output_layer must be a Layer, received: %s" % type(output_layer))
    204     self._cell = cell
    205     self._output_layer = output_layer
    206 
    207     if callable(embedding):
    208       self._embedding_fn = embedding
    209     else:
    210       self._embedding_fn = (
    211           lambda ids: embedding_ops.embedding_lookup(embedding, ids))
    212 
    213     self._start_tokens = ops.convert_to_tensor(
    214         start_tokens, dtype=dtypes.int32, name="start_tokens")
    215     if self._start_tokens.get_shape().ndims != 1:
    216       raise ValueError("start_tokens must be a vector")
    217     self._end_token = ops.convert_to_tensor(
    218         end_token, dtype=dtypes.int32, name="end_token")
    219     if self._end_token.get_shape().ndims != 0:
    220       raise ValueError("end_token must be a scalar")
    221 
    222     self._batch_size = array_ops.size(start_tokens)
    223     self._beam_width = beam_width
    224     self._length_penalty_weight = length_penalty_weight
    225     self._initial_cell_state = nest.map_structure(
    226         self._maybe_split_batch_beams, initial_state, self._cell.state_size)
    227     self._start_tokens = array_ops.tile(
    228         array_ops.expand_dims(self._start_tokens, 1), [1, self._beam_width])
    229     self._start_inputs = self._embedding_fn(self._start_tokens)
    230 
    231     self._finished = array_ops.one_hot(
    232         array_ops.zeros([self._batch_size], dtype=dtypes.int32),
    233         depth=self._beam_width,
    234         on_value=False,
    235         off_value=True,
    236         dtype=dtypes.bool)
    237 
  @property
  def batch_size(self):
    return self._batch_size

  def _rnn_output_size(self):
    size = self._cell.output_size
    if self._output_layer is None:
      return size
    else:
      # To use layer's compute_output_shape, we need to convert the
      # RNNCell's output_size entries into shapes with an unknown
      # batch size.  We then pass this through the layer's
      # compute_output_shape and read off all but the first (batch)
      # dimensions to get the output size of the rnn with the layer
      # applied to the top.
      output_shape_with_unknown_batch = nest.map_structure(
          lambda s: tensor_shape.TensorShape([None]).concatenate(s), size)
      layer_output_shape = self._output_layer.compute_output_shape(
          output_shape_with_unknown_batch)
      return nest.map_structure(lambda s: s[1:], layer_output_shape)

  @property
  def tracks_own_finished(self):
    """The BeamSearchDecoder shuffles its beams and their finished state.

    For this reason, it conflicts with the `dynamic_decode` function's
    tracking of finished states.  Setting this property to true avoids
    early stopping of decoding due to mismanagement of the finished state
    in `dynamic_decode`.

    Returns:
      `True`.
    """
    return True

  @property
  def output_size(self):
    # Return the cell output and the id
    return BeamSearchDecoderOutput(
        scores=tensor_shape.TensorShape([self._beam_width]),
        predicted_ids=tensor_shape.TensorShape([self._beam_width]),
        parent_ids=tensor_shape.TensorShape([self._beam_width]))

  @property
  def output_dtype(self):
    # Assume the output dtype matches the dtype of the first component of
    # the initial cell state.  Return that dtype for the scores structure
    # and int32 for the ids.
    dtype = nest.flatten(self._initial_cell_state)[0].dtype
    return BeamSearchDecoderOutput(
        scores=nest.map_structure(lambda _: dtype, self._rnn_output_size()),
        predicted_ids=dtypes.int32,
        parent_ids=dtypes.int32)

  def initialize(self, name=None):
    """Initialize the decoder.

    Args:
      name: Name scope for any created operations.

    Returns:
      `(finished, start_inputs, initial_state)`.
    """
    finished, start_inputs = self._finished, self._start_inputs

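    # At time 0 all beams are identical, so beam 0 is given log prob 0.0 and
    # all other beams log prob -inf below; the first top_k step therefore
    # selects only continuations of the first beam, avoiding duplicates.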
    log_probs = array_ops.one_hot(  # shape(batch_sz, beam_sz)
        array_ops.zeros([self._batch_size], dtype=dtypes.int32),
        depth=self._beam_width,
        on_value=0.0,
        off_value=-np.Inf,
        dtype=nest.flatten(self._initial_cell_state)[0].dtype)

    initial_state = BeamSearchDecoderState(
        cell_state=self._initial_cell_state,
        log_probs=log_probs,
        finished=finished,
        lengths=array_ops.zeros(
            [self._batch_size, self._beam_width], dtype=dtypes.int64))

    return (finished, start_inputs, initial_state)

  def finalize(self, outputs, final_state, sequence_lengths):
    """Finalize and return the predicted_ids.

    Args:
      outputs: An instance of BeamSearchDecoderOutput.
      final_state: An instance of BeamSearchDecoderState. Passed through to the
        output.
      sequence_lengths: An `int64` tensor shaped `[batch_size, beam_width]`.
        The sequence lengths determined for each beam during decode.
        **NOTE** These are ignored; the updated sequence lengths are stored in
        `final_state.lengths`.

    Returns:
      outputs: An instance of `FinalBeamSearchDecoderOutput` where the
        predicted_ids are the result of calling `gather_tree`.
      final_state: The same input instance of `BeamSearchDecoderState`.
    """
    del sequence_lengths
    # Get max_sequence_length across all beams for each batch.
    max_sequence_lengths = math_ops.to_int32(
        math_ops.reduce_max(final_state.lengths, axis=1))
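    # gather_tree backtracks through parent_ids from the final time step so
    # that each beam of the returned predicted_ids reads as one coherent
    # token sequence.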
    predicted_ids = beam_search_ops.gather_tree(
        outputs.predicted_ids,
        outputs.parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=self._end_token)
    outputs = FinalBeamSearchDecoderOutput(
        beam_search_decoder_output=outputs, predicted_ids=predicted_ids)
    return outputs, final_state

  def _merge_batch_beams(self, t, s=None):
    """Merges the tensor from a batch of beams into a batch by beams.

    More exactly, t is a tensor of dimension [batch_size, beam_width, s]. We
    reshape this into [batch_size*beam_width, s]

    Args:
      t: Tensor of dimension [batch_size, beam_width, s]
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size * beam_width, s].
    """
    if isinstance(s, ops.Tensor):
      s = tensor_shape.as_shape(tensor_util.constant_value(s))
    else:
      s = tensor_shape.TensorShape(s)
    t_shape = array_ops.shape(t)
    static_batch_size = tensor_util.constant_value(self._batch_size)
    batch_size_beam_width = (
        None
        if static_batch_size is None else static_batch_size * self._beam_width)
    reshaped_t = array_ops.reshape(
        t,
        array_ops.concat(([self._batch_size * self._beam_width], t_shape[2:]),
                         0))
    reshaped_t.set_shape(
        (tensor_shape.TensorShape([batch_size_beam_width]).concatenate(s)))
    return reshaped_t

  def _split_batch_beams(self, t, s=None):
    """Splits the tensor from a batch by beams into a batch of beams.

    More exactly, t is a tensor of dimension [batch_size*beam_width, s]. We
    reshape this into [batch_size, beam_width, s]

    Args:
      t: Tensor of dimension [batch_size*beam_width, s].
      s: (Possibly known) depth shape.

    Returns:
      A reshaped version of t with dimension [batch_size, beam_width, s].

    Raises:
      ValueError: If, after reshaping, the new tensor is not shaped
        `[batch_size, beam_width, s]` (assuming batch_size and beam_width
        are known statically).
    """
    if isinstance(s, ops.Tensor):
      s = tensor_shape.TensorShape(tensor_util.constant_value(s))
    else:
      s = tensor_shape.TensorShape(s)
    t_shape = array_ops.shape(t)
    reshaped_t = array_ops.reshape(
        t,
        array_ops.concat(([self._batch_size, self._beam_width], t_shape[1:]),
                         0))
    static_batch_size = tensor_util.constant_value(self._batch_size)
    expected_reshaped_shape = tensor_shape.TensorShape(
        [static_batch_size, self._beam_width]).concatenate(s)
    if not reshaped_t.shape.is_compatible_with(expected_reshaped_shape):
      raise ValueError("Unexpected behavior when reshaping between beam width "
                       "and batch size.  The reshaped tensor has shape: %s.  "
                       "We expected it to have shape "
                       "(batch_size, beam_width, depth) == %s.  Perhaps you "
                       "forgot to create a zero_state with "
                       "batch_size=encoder_batch_size * beam_width?" %
                       (reshaped_t.shape, expected_reshaped_shape))
    reshaped_t.set_shape(expected_reshaped_shape)
    return reshaped_t

  def _maybe_split_batch_beams(self, t, s):
    """Maybe splits the tensor from a batch by beams into a batch of beams.

    We do this so that we can use nest and not run into problems with shapes.

    Args:
      t: `Tensor`, either scalar or shaped `[batch_size * beam_width] + s`.
      s: `Tensor`, Python int, or `TensorShape`.

    Returns:
      If `t` is a matrix or higher order tensor, then the return value is
      `t` reshaped to `[batch_size, beam_width] + s`.  Otherwise `t` is
      returned unchanged.

    Raises:
      TypeError: If `t` is an instance of `TensorArray`.
      ValueError: If the rank of `t` is not statically known.
    """
    _check_maybe(t)
    if t.shape.ndims >= 1:
      return self._split_batch_beams(t, s)
    else:
      return t

  def _maybe_merge_batch_beams(self, t, s):
    """Maybe merges the tensor from a batch of beams into a batch by beams.

    More exactly, if `t` is a tensor of dimension
    `[batch_size, beam_width] + s`, then we reshape it to
    `[batch_size * beam_width] + s`.

    Args:
      t: `Tensor` of dimension `[batch_size, beam_width] + s`.
      s: `Tensor`, Python int, or `TensorShape`.

    Returns:
      If `t` is a matrix or higher order tensor, then the return value is
      `t` reshaped to `[batch_size * beam_width] + s`.  Otherwise `t` is
      returned unchanged.

    Raises:
      TypeError: If `t` is an instance of `TensorArray`.
      ValueError: If the rank of `t` is not statically known.
    """
    _check_maybe(t)
    if t.shape.ndims >= 2:
      return self._merge_batch_beams(t, s)
    else:
      return t

  def step(self, time, inputs, state, name=None):
    """Perform a decoding step.

    Args:
      time: scalar `int32` tensor.
      inputs: A (structure of) input tensors.
      state: A (structure of) state tensors and TensorArrays.
      name: Name scope for any created operations.

    Returns:
      `(outputs, next_state, next_inputs, finished)`.
    """
    batch_size = self._batch_size
    beam_width = self._beam_width
    end_token = self._end_token
    length_penalty_weight = self._length_penalty_weight

    with ops.name_scope(name, "BeamSearchDecoderStep", (time, inputs, state)):
      cell_state = state.cell_state
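      # Shape bookkeeping: inputs arrive shaped [batch_size, beam_width, ...]
      # and are merged to [batch_size * beam_width, ...] so the wrapped cell
      # sees an ordinary batch; outputs and states are split back afterwards.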
      inputs = nest.map_structure(
          lambda inp: self._merge_batch_beams(inp, s=inp.shape[2:]), inputs)
      cell_state = nest.map_structure(self._maybe_merge_batch_beams, cell_state,
                                      self._cell.state_size)
      cell_outputs, next_cell_state = self._cell(inputs, cell_state)
      cell_outputs = nest.map_structure(
          lambda out: self._split_batch_beams(out, out.shape[1:]), cell_outputs)
      next_cell_state = nest.map_structure(
          self._maybe_split_batch_beams, next_cell_state, self._cell.state_size)

      if self._output_layer is not None:
        cell_outputs = self._output_layer(cell_outputs)

      beam_search_output, beam_search_state = _beam_search_step(
          time=time,
          logits=cell_outputs,
          next_cell_state=next_cell_state,
          beam_state=state,
          batch_size=batch_size,
          beam_width=beam_width,
          end_token=end_token,
          length_penalty_weight=length_penalty_weight)

      finished = beam_search_state.finished
      sample_ids = beam_search_output.predicted_ids
      next_inputs = control_flow_ops.cond(
          math_ops.reduce_all(finished), lambda: self._start_inputs,
          lambda: self._embedding_fn(sample_ids))

    return (beam_search_output, beam_search_state, next_inputs, finished)


def _beam_search_step(time, logits, next_cell_state, beam_state, batch_size,
                      beam_width, end_token, length_penalty_weight):
  """Performs a single step of Beam Search Decoding.

  Args:
    time: Beam search time step, should start at 0. At time 0 we assume
      that all beams are equal and consider only the first beam for
      continuations.
    logits: Logits at the current time step. A tensor of shape
      `[batch_size, beam_width, vocab_size]`.
    next_cell_state: The next state from the cell, e.g. an instance of
      AttentionWrapperState if the cell is attentional.
    beam_state: Current state of the beam search.
      An instance of `BeamSearchDecoderState`.
    batch_size: The batch size for this input.
    beam_width: Python int.  The size of the beams.
    end_token: The int32 end token.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    A tuple `(output, next_state)`, where `output` is a
    `BeamSearchDecoderOutput` for the current step and `next_state` is the
    new `BeamSearchDecoderState`.
  """
  static_batch_size = tensor_util.constant_value(batch_size)

  # Calculate the current lengths of the predictions
  prediction_lengths = beam_state.lengths
  previously_finished = beam_state.finished

  # Calculate the total log probs for the new hypotheses
  # Final Shape: [batch_size, beam_width, vocab_size]
  step_log_probs = nn_ops.log_softmax(logits)
  step_log_probs = _mask_probs(step_log_probs, end_token, previously_finished)
  total_probs = array_ops.expand_dims(beam_state.log_probs, 2) + step_log_probs

  # Calculate the continuation lengths by adding to all continuing beams.
  vocab_size = logits.shape[-1].value or array_ops.shape(logits)[-1]
  lengths_to_add = array_ops.one_hot(
      indices=array_ops.fill([batch_size, beam_width], end_token),
      depth=vocab_size,
      on_value=np.int64(0),
      off_value=np.int64(1),
      dtype=dtypes.int64)
  add_mask = math_ops.to_int64(math_ops.logical_not(previously_finished))
  lengths_to_add *= array_ops.expand_dims(add_mask, 2)
  new_prediction_lengths = (
      lengths_to_add + array_ops.expand_dims(prediction_lengths, 2))

  # Calculate the scores for each beam
  scores = _get_scores(
      log_probs=total_probs,
      sequence_lengths=new_prediction_lengths,
      length_penalty_weight=length_penalty_weight)

  # At the first time step, initialize() gave only the first beam a finite
  # log prob, so flattening scores over (beam, word) still considers only
  # continuations of that initial beam.
  scores_flat = array_ops.reshape(scores, [batch_size, -1])

  # Pick the next beams according to the specified successors function
  next_beam_size = ops.convert_to_tensor(
      beam_width, dtype=dtypes.int32, name="beam_width")
  next_beam_scores, word_indices = nn_ops.top_k(scores_flat, k=next_beam_size)

  next_beam_scores.set_shape([static_batch_size, beam_width])
  word_indices.set_shape([static_batch_size, beam_width])

  # Pick out the probs, beam_ids, and states according to the chosen
  # predictions
  next_beam_probs = _tensor_gather_helper(
      gather_indices=word_indices,
      gather_from=total_probs,
      batch_size=batch_size,
      range_size=beam_width * vocab_size,
      gather_shape=[-1],
      name="next_beam_probs")
  # Note: just doing the following
  #   math_ops.to_int32(word_indices % vocab_size,
  #       name="next_beam_word_ids")
  # would be a lot cleaner but for reasons unclear, that hides the results of
  # the op which prevents capturing it with tfdbg debug ops.
  raw_next_word_ids = math_ops.mod(
      word_indices, vocab_size, name="next_beam_word_ids")
  next_word_ids = math_ops.to_int32(raw_next_word_ids)
  next_beam_ids = math_ops.to_int32(
      word_indices / vocab_size, name="next_beam_parent_ids")
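  # For example, with vocab_size=5 the flattened index 7 decodes to word id
  # 7 % 5 == 2, continuing parent beam 7 // 5 == 1.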

  # Append new ids to current predictions
  previously_finished = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=previously_finished,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_finished = math_ops.logical_or(
      previously_finished,
      math_ops.equal(next_word_ids, end_token),
      name="next_beam_finished")

  # Calculate the length of the next predictions.
  # 1. Finished beams remain unchanged.
  # 2. Beams that are now finished (EOS predicted) have their length
  #    increased by 1.
  # 3. Beams that are not yet finished have their length increased by 1.
  lengths_to_add = math_ops.to_int64(math_ops.logical_not(previously_finished))
  next_prediction_len = _tensor_gather_helper(
      gather_indices=next_beam_ids,
      gather_from=beam_state.lengths,
      batch_size=batch_size,
      range_size=beam_width,
      gather_shape=[-1])
  next_prediction_len += lengths_to_add

  # Pick out the cell_states according to the next_beam_ids. We use a
  # different gather_shape here because the cell_state tensors, i.e.
  # the tensors that would be gathered from, all have dimension
  # greater than two and we need to preserve those dimensions.
  # pylint: disable=g-long-lambda
  next_cell_state = nest.map_structure(
      lambda gather_from: _maybe_tensor_gather_helper(
          gather_indices=next_beam_ids,
          gather_from=gather_from,
          batch_size=batch_size,
          range_size=beam_width,
          gather_shape=[batch_size * beam_width, -1]),
      next_cell_state)
  # pylint: enable=g-long-lambda

  next_state = BeamSearchDecoderState(
      cell_state=next_cell_state,
      log_probs=next_beam_probs,
      lengths=next_prediction_len,
      finished=next_finished)

  output = BeamSearchDecoderOutput(
      scores=next_beam_scores,
      predicted_ids=next_word_ids,
      parent_ids=next_beam_ids)

  return output, next_state


def _get_scores(log_probs, sequence_lengths, length_penalty_weight):
  """Calculates scores for beam search hypotheses.

  Args:
    log_probs: The log probabilities with shape
      `[batch_size, beam_width, vocab_size]`.
    sequence_lengths: The array of sequence lengths.
    length_penalty_weight: Float weight to penalize length. Disabled with 0.0.

  Returns:
    The scores normalized by the length_penalty.
  """
  length_penalty_ = _length_penalty(
      sequence_lengths=sequence_lengths, penalty_factor=length_penalty_weight)
  return log_probs / length_penalty_


def _length_penalty(sequence_lengths, penalty_factor):
  """Calculates the length penalty. See https://arxiv.org/abs/1609.08144.

  Returns the length penalty tensor:
  ```
  [(5+sequence_lengths)/6]**penalty_factor
  ```
  where all operations are performed element-wise.

  Args:
    sequence_lengths: `Tensor`, the sequence lengths of each hypothesis.
    penalty_factor: A scalar that weights the length penalty.

  Returns:
    If the penalty is `0`, returns the scalar `1.0`.  Otherwise returns
    the length penalty factor, a tensor with the same shape as
    `sequence_lengths`.
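
  For example, with `penalty_factor=0.6` a hypothesis of length 7 is
  penalized by `((5. + 7.) / 6.)**0.6 == 2.**0.6`, roughly `1.52`.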
    692   """
    693   penalty_factor = ops.convert_to_tensor(penalty_factor, name="penalty_factor")
    694   penalty_factor.set_shape(())  # penalty should be a scalar.
    695   static_penalty = tensor_util.constant_value(penalty_factor)
    696   if static_penalty is not None and static_penalty == 0:
    697     return 1.0
    698   return math_ops.div((5. + math_ops.to_float(sequence_lengths))
    699                       **penalty_factor, (5. + 1.)**penalty_factor)
    700 
    701 
def _mask_probs(probs, eos_token, finished):
  """Masks log probabilities.

  The result is that finished beams allocate all probability mass to eos and
  unfinished beams remain unchanged.

  Args:
    probs: Log probabilities of shape `[batch_size, beam_width, vocab_size]`.
    eos_token: An int32 id corresponding to the EOS token to allocate
      probability to.
    finished: A boolean tensor of shape `[batch_size, beam_width]` that
      specifies which elements in the beam are finished already.

  Returns:
    A tensor of shape `[batch_size, beam_width, vocab_size]`, where unfinished
    beams stay unchanged and finished beams are replaced with a tensor with all
    probability on the EOS token.
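
  For example, a finished beam's vocabulary row becomes
  `[dtype.min, ..., 0., ..., dtype.min]` with the `0.` at `eos_token`, so a
  finished beam's total log probability no longer changes as decoding
  continues.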
    719   """
    720   vocab_size = array_ops.shape(probs)[2]
    721   # All finished examples are replaced with a vector that has all
    722   # probability on EOS
    723   finished_row = array_ops.one_hot(
    724       eos_token,
    725       vocab_size,
    726       dtype=probs.dtype,
    727       on_value=0.,
    728       off_value=probs.dtype.min)
    729   finished_probs = array_ops.tile(
    730       array_ops.reshape(finished_row, [1, 1, -1]),
    731       array_ops.concat([array_ops.shape(finished), [1]], 0))
    732   finished_mask = array_ops.tile(
    733       array_ops.expand_dims(finished, 2), [1, 1, vocab_size])
    734 
    735   return array_ops.where(finished_mask, finished_probs, probs)
    736 
    737 
def _maybe_tensor_gather_helper(gather_indices, gather_from, batch_size,
                                range_size, gather_shape):
  """Maybe applies _tensor_gather_helper.

  This applies _tensor_gather_helper when the rank of gather_from is at least
  as large as the length of gather_shape. This is used in conjunction with
  nest so that we don't apply _tensor_gather_helper to inapplicable values
  like scalars.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
      or the original tensor if its dimensions are too small.
  """
  _check_maybe(gather_from)
  if gather_from.shape.ndims >= len(gather_shape):
    return _tensor_gather_helper(
        gather_indices=gather_indices,
        gather_from=gather_from,
        batch_size=batch_size,
        range_size=range_size,
        gather_shape=gather_shape)
  else:
    return gather_from


def _tensor_gather_helper(gather_indices,
                          gather_from,
                          batch_size,
                          range_size,
                          gather_shape,
                          name=None):
  """Helper for gathering the right indices from the tensor.

  This works by reshaping gather_from to gather_shape (e.g. [-1]) and then
  gathering from that according to the gather_indices, which are offset by
  the right amounts in order to preserve the batch order.

  Args:
    gather_indices: The tensor indices that we use to gather.
    gather_from: The tensor that we are gathering from.
    batch_size: The input batch size.
    range_size: The number of values in each range. Likely equal to beam_width.
    gather_shape: What we should reshape gather_from to in order to preserve the
      correct values. An example is when gather_from is the attention from an
      AttentionWrapperState with shape [batch_size, beam_width, attention_size].
      There, we want to preserve the attention_size elements, so gather_shape is
      [batch_size * beam_width, -1]. Then, upon reshape, we still have the
      attention_size as desired.
    name: The tensor name for set of operations. By default this is
      'tensor_gather_helper'. The final output is named 'output'.

  Returns:
    output: Gathered tensor of shape tf.shape(gather_from)[:1+len(gather_shape)]
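
  For example (with `batch_size=2`, `range_size=3`, and `gather_shape=[-1]`):
  `gather_indices=[[2, 0], [1, 1]]` is offset by `[[0], [3]]` and flattened
  to `[2, 0, 4, 4]`, gathering `gather_from[0, 2]`, `gather_from[0, 0]`,
  `gather_from[1, 1]` and `gather_from[1, 1]` from the reshaped tensor.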
    802   """
    803   with ops.name_scope(name, "tensor_gather_helper"):
    804     range_ = array_ops.expand_dims(math_ops.range(batch_size) * range_size, 1)
    805     gather_indices = array_ops.reshape(gather_indices + range_, [-1])
    806     output = array_ops.gather(
    807         array_ops.reshape(gather_from, gather_shape), gather_indices)
    808     final_shape = array_ops.shape(gather_from)[:1 + len(gather_shape)]
    809     static_batch_size = tensor_util.constant_value(batch_size)
    810     final_static_shape = (
    811         tensor_shape.TensorShape([static_batch_size]).concatenate(
    812             gather_from.shape[1:1 + len(gather_shape)]))
    813     output = array_ops.reshape(output, final_shape, name="output")
    814     output.set_shape(final_static_shape)
    815     return output
    816