      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """A powerful dynamic attention wrapper object."""
     16 
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 
     21 import collections
     22 import functools
     23 import math
     24 
     25 import numpy as np
     26 
     27 from tensorflow.contrib.framework.python.framework import tensor_util
     28 from tensorflow.python.framework import dtypes
     29 from tensorflow.python.framework import ops
     30 from tensorflow.python.framework import tensor_shape
     31 from tensorflow.python.layers import base as layers_base
     32 from tensorflow.python.layers import core as layers_core
     33 from tensorflow.python.ops import array_ops
     34 from tensorflow.python.ops import check_ops
     35 from tensorflow.python.ops import clip_ops
     36 from tensorflow.python.ops import functional_ops
     37 from tensorflow.python.ops import init_ops
     38 from tensorflow.python.ops import math_ops
     39 from tensorflow.python.ops import nn_ops
     40 from tensorflow.python.ops import random_ops
     41 from tensorflow.python.ops import rnn_cell_impl
     42 from tensorflow.python.ops import tensor_array_ops
     43 from tensorflow.python.ops import variable_scope
     44 from tensorflow.python.util import nest
     45 
     46 
     47 __all__ = [
     48     "AttentionMechanism",
     49     "AttentionWrapper",
     50     "AttentionWrapperState",
     51     "LuongAttention",
     52     "BahdanauAttention",
     53     "hardmax",
     54     "safe_cumprod",
     55     "monotonic_attention",
     56     "BahdanauMonotonicAttention",
     57     "LuongMonotonicAttention",
     58 ]
     59 
     60 
     61 _zero_state_tensors = rnn_cell_impl._zero_state_tensors  # pylint: disable=protected-access
     62 
     63 
     64 class AttentionMechanism(object):
     65 
     66   @property
     67   def alignments_size(self):
     68     raise NotImplementedError
     69 
     70   @property
     71   def state_size(self):
     72     raise NotImplementedError
     73 
     74 
     75 def _prepare_memory(memory, memory_sequence_length, check_inner_dims_defined):
     76   """Convert to tensor and possibly mask `memory`.
     77 
     78   Args:
     79     memory: `Tensor`, shaped `[batch_size, max_time, ...]`.
     80     memory_sequence_length: `int32` `Tensor`, shaped `[batch_size]`.
     81     check_inner_dims_defined: Python boolean.  If `True`, the `memory`
     82       argument's shape is checked to ensure all but the two outermost
     83       dimensions are fully defined.
     84 
     85   Returns:
     86     A (possibly masked), checked, new `memory`.
     87 
     88   Raises:
     89     ValueError: If `check_inner_dims_defined` is `True` and not
     90       `memory.shape[2:].is_fully_defined()`.
     91   """
     92   memory = nest.map_structure(
     93       lambda m: ops.convert_to_tensor(m, name="memory"), memory)
     94   if memory_sequence_length is not None:
     95     memory_sequence_length = ops.convert_to_tensor(
     96         memory_sequence_length, name="memory_sequence_length")
     97   if check_inner_dims_defined:
     98     def _check_dims(m):
     99       if not m.get_shape()[2:].is_fully_defined():
    100         raise ValueError("Expected memory %s to have fully defined inner dims, "
    101                          "but saw shape: %s" % (m.name, m.get_shape()))
    102     nest.map_structure(_check_dims, memory)
    103   if memory_sequence_length is None:
    104     seq_len_mask = None
    105   else:
    106     seq_len_mask = array_ops.sequence_mask(
    107         memory_sequence_length,
    108         maxlen=array_ops.shape(nest.flatten(memory)[0])[1],
    109         dtype=nest.flatten(memory)[0].dtype)
    110     seq_len_batch_size = (
    111         memory_sequence_length.shape[0].value
    112         or array_ops.shape(memory_sequence_length)[0])
    113   def _maybe_mask(m, seq_len_mask):
    114     rank = m.get_shape().ndims
    115     rank = rank if rank is not None else array_ops.rank(m)
    116     extra_ones = array_ops.ones(rank - 2, dtype=dtypes.int32)
    117     m_batch_size = m.shape[0].value or array_ops.shape(m)[0]
    118     if memory_sequence_length is not None:
    119       message = ("memory_sequence_length and memory tensor batch sizes do not "
    120                  "match.")
    121       with ops.control_dependencies([
    122           check_ops.assert_equal(
    123               seq_len_batch_size, m_batch_size, message=message)]):
    124         seq_len_mask = array_ops.reshape(
    125             seq_len_mask,
    126             array_ops.concat((array_ops.shape(seq_len_mask), extra_ones), 0))
    127         return m * seq_len_mask
    128     else:
    129       return m
    130   return nest.map_structure(lambda m: _maybe_mask(m, seq_len_mask), memory)
    131 
    132 
    133 def _maybe_mask_score(score, memory_sequence_length, score_mask_value):
    134   if memory_sequence_length is None:
    135     return score
     136   message = ("All values in memory_sequence_length must be greater than zero.")
    137   with ops.control_dependencies(
    138       [check_ops.assert_positive(memory_sequence_length, message=message)]):
    139     score_mask = array_ops.sequence_mask(
    140         memory_sequence_length, maxlen=array_ops.shape(score)[1])
    141     score_mask_values = score_mask_value * array_ops.ones_like(score)
    142     return array_ops.where(score_mask, score, score_mask_values)
    143 
    144 
    145 class _BaseAttentionMechanism(AttentionMechanism):
    146   """A base AttentionMechanism class providing common functionality.
    147 
    148   Common functionality includes:
    149     1. Storing the query and memory layers.
    150     2. Preprocessing and storing the memory.
    151   """
    152 
    153   def __init__(self,
    154                query_layer,
    155                memory,
    156                probability_fn,
    157                memory_sequence_length=None,
    158                memory_layer=None,
    159                check_inner_dims_defined=True,
    160                score_mask_value=None,
    161                name=None):
    162     """Construct base AttentionMechanism class.
    163 
    164     Args:
    165       query_layer: Callable.  Instance of `tf.layers.Layer`.  The layer's depth
    166         must match the depth of `memory_layer`.  If `query_layer` is not
    167         provided, the shape of `query` must match that of `memory_layer`.
    168       memory: The memory to query; usually the output of an RNN encoder.  This
    169         tensor should be shaped `[batch_size, max_time, ...]`.
    170       probability_fn: A `callable`.  Converts the score and previous alignments
    171         to probabilities. Its signature should be:
    172         `probabilities = probability_fn(score, state)`.
     173       memory_sequence_length: (optional) Sequence lengths for the batch entries
    174         in memory.  If provided, the memory tensor rows are masked with zeros
    175         for values past the respective sequence lengths.
    176       memory_layer: Instance of `tf.layers.Layer` (may be None).  The layer's
    177         depth must match the depth of `query_layer`.
    178         If `memory_layer` is not provided, the shape of `memory` must match
    179         that of `query_layer`.
    180       check_inner_dims_defined: Python boolean.  If `True`, the `memory`
    181         argument's shape is checked to ensure all but the two outermost
    182         dimensions are fully defined.
     183       score_mask_value: (optional) The mask value for score before passing into
    184         `probability_fn`. The default is -inf. Only used if
    185         `memory_sequence_length` is not None.
    186       name: Name to use when creating ops.
    187     """
    188     if (query_layer is not None
    189         and not isinstance(query_layer, layers_base.Layer)):
    190       raise TypeError(
    191           "query_layer is not a Layer: %s" % type(query_layer).__name__)
    192     if (memory_layer is not None
    193         and not isinstance(memory_layer, layers_base.Layer)):
    194       raise TypeError(
    195           "memory_layer is not a Layer: %s" % type(memory_layer).__name__)
    196     self._query_layer = query_layer
    197     self._memory_layer = memory_layer
    198     self.dtype = memory_layer.dtype
    199     if not callable(probability_fn):
    200       raise TypeError("probability_fn must be callable, saw type: %s" %
    201                       type(probability_fn).__name__)
    202     if score_mask_value is None:
    203       score_mask_value = dtypes.as_dtype(
    204           self._memory_layer.dtype).as_numpy_dtype(-np.inf)
    205     self._probability_fn = lambda score, prev: (  # pylint:disable=g-long-lambda
    206         probability_fn(
    207             _maybe_mask_score(score, memory_sequence_length, score_mask_value),
    208             prev))
    209     with ops.name_scope(
    210         name, "BaseAttentionMechanismInit", nest.flatten(memory)):
    211       self._values = _prepare_memory(
    212           memory, memory_sequence_length,
    213           check_inner_dims_defined=check_inner_dims_defined)
    214       self._keys = (
    215           self.memory_layer(self._values) if self.memory_layer  # pylint: disable=not-callable
    216           else self._values)
    217       self._batch_size = (
    218           self._keys.shape[0].value or array_ops.shape(self._keys)[0])
    219       self._alignments_size = (self._keys.shape[1].value or
    220                                array_ops.shape(self._keys)[1])
    221 
    222   @property
    223   def memory_layer(self):
    224     return self._memory_layer
    225 
    226   @property
    227   def query_layer(self):
    228     return self._query_layer
    229 
    230   @property
    231   def values(self):
    232     return self._values
    233 
    234   @property
    235   def keys(self):
    236     return self._keys
    237 
    238   @property
    239   def batch_size(self):
    240     return self._batch_size
    241 
    242   @property
    243   def alignments_size(self):
    244     return self._alignments_size
    245 
    246   @property
    247   def state_size(self):
    248     return self._alignments_size
    249 
    250   def initial_alignments(self, batch_size, dtype):
    251     """Creates the initial alignment values for the `AttentionWrapper` class.
    252 
    253     This is important for AttentionMechanisms that use the previous alignment
    254     to calculate the alignment at the next time step (e.g. monotonic attention).
    255 
    256     The default behavior is to return a tensor of all zeros.
    257 
    258     Args:
    259       batch_size: `int32` scalar, the batch_size.
    260       dtype: The `dtype`.
    261 
    262     Returns:
    263       A `dtype` tensor shaped `[batch_size, alignments_size]`
    264       (`alignments_size` is the values' `max_time`).
    265     """
    266     max_time = self._alignments_size
    267     return _zero_state_tensors(max_time, batch_size, dtype)
    268 
    269   def initial_state(self, batch_size, dtype):
    270     """Creates the initial state values for the `AttentionWrapper` class.
    271 
    272     This is important for AttentionMechanisms that use the previous alignment
    273     to calculate the alignment at the next time step (e.g. monotonic attention).
    274 
    275     The default behavior is to return the same output as initial_alignments.
    276 
    277     Args:
    278       batch_size: `int32` scalar, the batch_size.
    279       dtype: The `dtype`.
    280 
    281     Returns:
    282       A structure of all-zero tensors with shapes as described by `state_size`.
    283     """
    284     return self.initial_alignments(batch_size, dtype)
    285 
    286 
    287 def _luong_score(query, keys, scale):
    288   """Implements Luong-style (multiplicative) scoring function.
    289 
    290   This attention has two forms.  The first is standard Luong attention,
    291   as described in:
    292 
    293   Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
    294   "Effective Approaches to Attention-based Neural Machine Translation."
    295   EMNLP 2015.  https://arxiv.org/abs/1508.04025
    296 
    297   The second is the scaled form inspired partly by the normalized form of
    298   Bahdanau attention.
    299 
    300   To enable the second form, call this function with `scale=True`.
    301 
    302   Args:
    303     query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    304     keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    305     scale: Whether to apply a scale to the score function.
    306 
    307   Returns:
    308     A `[batch_size, max_time]` tensor of unnormalized score values.
    309 
    310   Raises:
    311     ValueError: If `key` and `query` depths do not match.
    312   """
    313   depth = query.get_shape()[-1]
    314   key_units = keys.get_shape()[-1]
    315   if depth != key_units:
    316     raise ValueError(
    317         "Incompatible or unknown inner dimensions between query and keys.  "
    318         "Query (%s) has units: %s.  Keys (%s) have units: %s.  "
    319         "Perhaps you need to set num_units to the keys' dimension (%s)?"
    320         % (query, depth, keys, key_units, key_units))
    321   dtype = query.dtype
    322 
    323   # Reshape from [batch_size, depth] to [batch_size, 1, depth]
    324   # for matmul.
    325   query = array_ops.expand_dims(query, 1)
    326 
    327   # Inner product along the query units dimension.
    328   # matmul shapes: query is [batch_size, 1, depth] and
    329   #                keys is [batch_size, max_time, depth].
    330   # the inner product is asked to **transpose keys' inner shape** to get a
    331   # batched matmul on:
    332   #   [batch_size, 1, depth] . [batch_size, depth, max_time]
    333   # resulting in an output shape of:
    334   #   [batch_size, 1, max_time].
    335   # we then squeeze out the center singleton dimension.
    336   score = math_ops.matmul(query, keys, transpose_b=True)
    337   score = array_ops.squeeze(score, [1])
    338 
    339   if scale:
    340     # Scalar used in weight scaling
    341     g = variable_scope.get_variable(
    342         "attention_g", dtype=dtype, initializer=1.)
    343     score = g * score
    344   return score
    345 
    346 
    347 class LuongAttention(_BaseAttentionMechanism):
    348   """Implements Luong-style (multiplicative) attention scoring.
    349 
    350   This attention has two forms.  The first is standard Luong attention,
    351   as described in:
    352 
    353   Minh-Thang Luong, Hieu Pham, Christopher D. Manning.
    354   "Effective Approaches to Attention-based Neural Machine Translation."
    355   EMNLP 2015.  https://arxiv.org/abs/1508.04025
    356 
    357   The second is the scaled form inspired partly by the normalized form of
    358   Bahdanau attention.
    359 
    360   To enable the second form, construct the object with parameter
    361   `scale=True`.
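
           Example, a minimal usage sketch (`encoder_outputs`, `source_lengths` and
           `decoder_cell` below are placeholders for tensors/cells supplied by the
           surrounding model, not names defined in this module):

           ```python
           attention_mechanism = tf.contrib.seq2seq.LuongAttention(
               num_units=decoder_cell.output_size,  # must equal the query depth
               memory=encoder_outputs,              # [batch_size, max_time, depth]
               memory_sequence_length=source_lengths)
           attention_cell = tf.contrib.seq2seq.AttentionWrapper(
               decoder_cell, attention_mechanism)
           ```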
    362   """
    363 
    364   def __init__(self,
    365                num_units,
    366                memory,
    367                memory_sequence_length=None,
    368                scale=False,
    369                probability_fn=None,
    370                score_mask_value=None,
    371                dtype=None,
    372                name="LuongAttention"):
     373     """Construct the LuongAttention mechanism.
    374 
    375     Args:
    376       num_units: The depth of the attention mechanism.
    377       memory: The memory to query; usually the output of an RNN encoder.  This
    378         tensor should be shaped `[batch_size, max_time, ...]`.
    379       memory_sequence_length: (optional) Sequence lengths for the batch entries
    380         in memory.  If provided, the memory tensor rows are masked with zeros
    381         for values past the respective sequence lengths.
    382       scale: Python boolean.  Whether to scale the energy term.
    383       probability_fn: (optional) A `callable`.  Converts the score to
    384         probabilities.  The default is @{tf.nn.softmax}. Other options include
    385         @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
    386         Its signature should be: `probabilities = probability_fn(score)`.
    387       score_mask_value: (optional) The mask value for score before passing into
    388         `probability_fn`. The default is -inf. Only used if
    389         `memory_sequence_length` is not None.
    390       dtype: The data type for the memory layer of the attention mechanism.
    391       name: Name to use when creating ops.
    392     """
    393     # For LuongAttention, we only transform the memory layer; thus
     394     # num_units **must** match the expected query depth.
    395     if probability_fn is None:
    396       probability_fn = nn_ops.softmax
    397     if dtype is None:
    398       dtype = dtypes.float32
    399     wrapped_probability_fn = lambda score, _: probability_fn(score)
    400     super(LuongAttention, self).__init__(
    401         query_layer=None,
    402         memory_layer=layers_core.Dense(
    403             num_units, name="memory_layer", use_bias=False, dtype=dtype),
    404         memory=memory,
    405         probability_fn=wrapped_probability_fn,
    406         memory_sequence_length=memory_sequence_length,
    407         score_mask_value=score_mask_value,
    408         name=name)
    409     self._num_units = num_units
    410     self._scale = scale
    411     self._name = name
    412 
    413   def __call__(self, query, state):
    414     """Score the query based on the keys and values.
    415 
    416     Args:
    417       query: Tensor of dtype matching `self.values` and shape
    418         `[batch_size, query_depth]`.
    419       state: Tensor of dtype matching `self.values` and shape
    420         `[batch_size, alignments_size]`
    421         (`alignments_size` is memory's `max_time`).
    422 
    423     Returns:
    424       alignments: Tensor of dtype matching `self.values` and shape
    425         `[batch_size, alignments_size]` (`alignments_size` is memory's
    426         `max_time`).
    427     """
    428     with variable_scope.variable_scope(None, "luong_attention", [query]):
    429       score = _luong_score(query, self._keys, self._scale)
    430     alignments = self._probability_fn(score, state)
    431     next_state = alignments
    432     return alignments, next_state
    433 
    434 
    435 def _bahdanau_score(processed_query, keys, normalize):
    436   """Implements Bahdanau-style (additive) scoring function.
    437 
     438   This attention has two forms.  The first is Bahdanau attention,
    439   as described in:
    440 
    441   Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
    442   "Neural Machine Translation by Jointly Learning to Align and Translate."
    443   ICLR 2015. https://arxiv.org/abs/1409.0473
    444 
    445   The second is the normalized form.  This form is inspired by the
    446   weight normalization article:
    447 
    448   Tim Salimans, Diederik P. Kingma.
    449   "Weight Normalization: A Simple Reparameterization to Accelerate
    450    Training of Deep Neural Networks."
    451   https://arxiv.org/abs/1602.07868
    452 
    453   To enable the second form, set `normalize=True`.
    454 
    455   Args:
    456     processed_query: Tensor, shape `[batch_size, num_units]` to compare to keys.
    457     keys: Processed memory, shape `[batch_size, max_time, num_units]`.
    458     normalize: Whether to normalize the score function.
    459 
    460   Returns:
    461     A `[batch_size, max_time]` tensor of unnormalized score values.
    462   """
    463   dtype = processed_query.dtype
    464   # Get the number of hidden units from the trailing dimension of keys
    465   num_units = keys.shape[2].value or array_ops.shape(keys)[2]
    466   # Reshape from [batch_size, ...] to [batch_size, 1, ...] for broadcasting.
    467   processed_query = array_ops.expand_dims(processed_query, 1)
    468   v = variable_scope.get_variable(
    469       "attention_v", [num_units], dtype=dtype)
    470   if normalize:
    471     # Scalar used in weight normalization
    472     g = variable_scope.get_variable(
    473         "attention_g", dtype=dtype,
    474         initializer=math.sqrt((1. / num_units)))
    475     # Bias added prior to the nonlinearity
    476     b = variable_scope.get_variable(
    477         "attention_b", [num_units], dtype=dtype,
    478         initializer=init_ops.zeros_initializer())
    479     # normed_v = g * v / ||v||
    480     normed_v = g * v * math_ops.rsqrt(
    481         math_ops.reduce_sum(math_ops.square(v)))
    482     return math_ops.reduce_sum(
    483         normed_v * math_ops.tanh(keys + processed_query + b), [2])
    484   else:
    485     return math_ops.reduce_sum(v * math_ops.tanh(keys + processed_query), [2])
    486 
    487 
    488 class BahdanauAttention(_BaseAttentionMechanism):
    489   """Implements Bahdanau-style (additive) attention.
    490 
    491   This attention has two forms.  The first is Bahdanau attention,
    492   as described in:
    493 
    494   Dzmitry Bahdanau, Kyunghyun Cho, Yoshua Bengio.
    495   "Neural Machine Translation by Jointly Learning to Align and Translate."
    496   ICLR 2015. https://arxiv.org/abs/1409.0473
    497 
    498   The second is the normalized form.  This form is inspired by the
    499   weight normalization article:
    500 
    501   Tim Salimans, Diederik P. Kingma.
    502   "Weight Normalization: A Simple Reparameterization to Accelerate
    503    Training of Deep Neural Networks."
    504   https://arxiv.org/abs/1602.07868
    505 
    506   To enable the second form, construct the object with parameter
    507   `normalize=True`.
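
           Example, a minimal sketch of the normalized form (`encoder_outputs` and
           `source_lengths` are placeholders supplied by the surrounding model):

           ```python
           attention_mechanism = tf.contrib.seq2seq.BahdanauAttention(
               num_units=128,           # query and memory are both projected to this
               memory=encoder_outputs,  # [batch_size, max_time, encoder_depth]
               memory_sequence_length=source_lengths,
               normalize=True)
           ```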
    508   """
    509 
    510   def __init__(self,
    511                num_units,
    512                memory,
    513                memory_sequence_length=None,
    514                normalize=False,
    515                probability_fn=None,
    516                score_mask_value=None,
    517                dtype=None,
    518                name="BahdanauAttention"):
    519     """Construct the Attention mechanism.
    520 
    521     Args:
    522       num_units: The depth of the query mechanism.
    523       memory: The memory to query; usually the output of an RNN encoder.  This
    524         tensor should be shaped `[batch_size, max_time, ...]`.
     525       memory_sequence_length: (optional) Sequence lengths for the batch entries
    526         in memory.  If provided, the memory tensor rows are masked with zeros
    527         for values past the respective sequence lengths.
    528       normalize: Python boolean.  Whether to normalize the energy term.
    529       probability_fn: (optional) A `callable`.  Converts the score to
    530         probabilities.  The default is @{tf.nn.softmax}. Other options include
    531         @{tf.contrib.seq2seq.hardmax} and @{tf.contrib.sparsemax.sparsemax}.
    532         Its signature should be: `probabilities = probability_fn(score)`.
     533       score_mask_value: (optional) The mask value for score before passing into
    534         `probability_fn`. The default is -inf. Only used if
    535         `memory_sequence_length` is not None.
    536       dtype: The data type for the query and memory layers of the attention
    537         mechanism.
    538       name: Name to use when creating ops.
    539     """
    540     if probability_fn is None:
    541       probability_fn = nn_ops.softmax
    542     if dtype is None:
    543       dtype = dtypes.float32
    544     wrapped_probability_fn = lambda score, _: probability_fn(score)
    545     super(BahdanauAttention, self).__init__(
    546         query_layer=layers_core.Dense(
    547             num_units, name="query_layer", use_bias=False, dtype=dtype),
    548         memory_layer=layers_core.Dense(
    549             num_units, name="memory_layer", use_bias=False, dtype=dtype),
    550         memory=memory,
    551         probability_fn=wrapped_probability_fn,
    552         memory_sequence_length=memory_sequence_length,
    553         score_mask_value=score_mask_value,
    554         name=name)
    555     self._num_units = num_units
    556     self._normalize = normalize
    557     self._name = name
    558 
    559   def __call__(self, query, state):
    560     """Score the query based on the keys and values.
    561 
    562     Args:
    563       query: Tensor of dtype matching `self.values` and shape
    564         `[batch_size, query_depth]`.
    565       state: Tensor of dtype matching `self.values` and shape
    566         `[batch_size, alignments_size]`
    567         (`alignments_size` is memory's `max_time`).
    568 
    569     Returns:
    570       alignments: Tensor of dtype matching `self.values` and shape
    571         `[batch_size, alignments_size]` (`alignments_size` is memory's
    572         `max_time`).
    573     """
    574     with variable_scope.variable_scope(None, "bahdanau_attention", [query]):
    575       processed_query = self.query_layer(query) if self.query_layer else query
    576       score = _bahdanau_score(processed_query, self._keys, self._normalize)
    577     alignments = self._probability_fn(score, state)
    578     next_state = alignments
    579     return alignments, next_state
    580 
    581 
    582 def safe_cumprod(x, *args, **kwargs):
    583   """Computes cumprod of x in logspace using cumsum to avoid underflow.
    584 
    585   The cumprod function and its gradient can result in numerical instabilities
    586   when its argument has very small and/or zero values.  As long as the argument
    587   is all positive, we can instead compute the cumulative product as
    588   exp(cumsum(log(x))).  This function can be called identically to tf.cumprod.
    589 
    590   Args:
    591     x: Tensor to take the cumulative product of.
    592     *args: Passed on to cumsum; these are identical to those in cumprod.
    593     **kwargs: Passed on to cumsum; these are identical to those in cumprod.
    594   Returns:
    595     Cumulative product of x.
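
           For example (illustrative values; for inputs in (0, 1] the result matches
           `tf.cumprod`, but it is computed as exp(cumsum(log(x))) for stability):

           ```python
           x = tf.constant([0.5, 0.2, 0.1])
           safe_cumprod(x)  # ~= [0.5, 0.1, 0.01]
           ```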
    596   """
    597   with ops.name_scope(None, "SafeCumprod", [x]):
    598     x = ops.convert_to_tensor(x, name="x")
    599     tiny = np.finfo(x.dtype.as_numpy_dtype).tiny
    600     return math_ops.exp(math_ops.cumsum(
    601         math_ops.log(clip_ops.clip_by_value(x, tiny, 1)), *args, **kwargs))
    602 
    603 
    604 def monotonic_attention(p_choose_i, previous_attention, mode):
    605   """Compute monotonic attention distribution from choosing probabilities.
    606 
    607   Monotonic attention implies that the input sequence is processed in an
    608   explicitly left-to-right manner when generating the output sequence.  In
    609   addition, once an input sequence element is attended to at a given output
    610   timestep, elements occurring before it cannot be attended to at subsequent
    611   output timesteps.  This function generates attention distributions according
    612   to these assumptions.  For more information, see ``Online and Linear-Time
    613   Attention by Enforcing Monotonic Alignments''.
    614 
    615   Args:
    616     p_choose_i: Probability of choosing input sequence/memory element i.  Should
    617       be of shape (batch_size, input_sequence_length), and should all be in the
    618       range [0, 1].
    619     previous_attention: The attention distribution from the previous output
    620       timestep.  Should be of shape (batch_size, input_sequence_length).  For
     621       the first output timestep, previous_attention[n] should be [1, 0, 0, ...,
    622       0] for all n in [0, ... batch_size - 1].
    623     mode: How to compute the attention distribution.  Must be one of
    624       'recursive', 'parallel', or 'hard'.
    625         * 'recursive' uses tf.scan to recursively compute the distribution.
    626           This is slowest but is exact, general, and does not suffer from
    627           numerical instabilities.
    628         * 'parallel' uses parallelized cumulative-sum and cumulative-product
    629           operations to compute a closed-form solution to the recurrence
    630           relation defining the attention distribution.  This makes it more
    631           efficient than 'recursive', but it requires numerical checks which
    632           make the distribution non-exact.  This can be a problem in particular
    633           when input_sequence_length is long and/or p_choose_i has entries very
    634           close to 0 or 1.
    635         * 'hard' requires that the probabilities in p_choose_i are all either 0
    636           or 1, and subsequently uses a more efficient and exact solution.
    637 
    638   Returns:
    639     A tensor of shape (batch_size, input_sequence_length) representing the
    640     attention distributions for each sequence in the batch.
    641 
    642   Raises:
    643     ValueError: mode is not one of 'recursive', 'parallel', 'hard'.
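
           For example, in 'hard' mode (an illustrative sketch with a batch size of 1
           and made-up values):

           ```python
           p_choose_i = tf.constant([[0., 0., 1., 1.]])          # all 0/1 for 'hard'
           previous_attention = tf.constant([[0., 1., 0., 0.]])  # attended index 1
           attention = monotonic_attention(
               p_choose_i, previous_attention, mode="hard")
           # attention == [[0., 0., 1., 0.]]: the first index at or after the
           # previously attended one with p_choose_i == 1.
           ```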
    644   """
    645   # Force things to be tensors
    646   p_choose_i = ops.convert_to_tensor(p_choose_i, name="p_choose_i")
    647   previous_attention = ops.convert_to_tensor(
    648       previous_attention, name="previous_attention")
    649   if mode == "recursive":
    650     # Use .shape[0].value when it's not None, or fall back on symbolic shape
    651     batch_size = p_choose_i.shape[0].value or array_ops.shape(p_choose_i)[0]
    652     # Compute [1, 1 - p_choose_i[0], 1 - p_choose_i[1], ..., 1 - p_choose_i[-2]]
    653     shifted_1mp_choose_i = array_ops.concat(
    654         [array_ops.ones((batch_size, 1)), 1 - p_choose_i[:, :-1]], 1)
    655     # Compute attention distribution recursively as
    656     # q[i] = (1 - p_choose_i[i])*q[i - 1] + previous_attention[i]
    657     # attention[i] = p_choose_i[i]*q[i]
    658     attention = p_choose_i*array_ops.transpose(functional_ops.scan(
    659         # Need to use reshape to remind TF of the shape between loop iterations
    660         lambda x, yz: array_ops.reshape(yz[0]*x + yz[1], (batch_size,)),
    661         # Loop variables yz[0] and yz[1]
    662         [array_ops.transpose(shifted_1mp_choose_i),
    663          array_ops.transpose(previous_attention)],
    664         # Initial value of x is just zeros
    665         array_ops.zeros((batch_size,))))
    666   elif mode == "parallel":
    667     # safe_cumprod computes cumprod in logspace with numeric checks
    668     cumprod_1mp_choose_i = safe_cumprod(1 - p_choose_i, axis=1, exclusive=True)
    669     # Compute recurrence relation solution
    670     attention = p_choose_i*cumprod_1mp_choose_i*math_ops.cumsum(
    671         previous_attention /
    672         # Clip cumprod_1mp to avoid divide-by-zero
    673         clip_ops.clip_by_value(cumprod_1mp_choose_i, 1e-10, 1.), axis=1)
    674   elif mode == "hard":
    675     # Remove any probabilities before the index chosen last time step
    676     p_choose_i *= math_ops.cumsum(previous_attention, axis=1)
    677     # Now, use exclusive cumprod to remove probabilities after the first
    678     # chosen index, like so:
    679     # p_choose_i = [0, 0, 0, 1, 1, 0, 1, 1]
    680     # cumprod(1 - p_choose_i, exclusive=True) = [1, 1, 1, 1, 0, 0, 0, 0]
    681     # Product of above: [0, 0, 0, 1, 0, 0, 0, 0]
    682     attention = p_choose_i*math_ops.cumprod(
    683         1 - p_choose_i, axis=1, exclusive=True)
    684   else:
    685     raise ValueError("mode must be 'recursive', 'parallel', or 'hard'.")
    686   return attention
    687 
    688 
    689 def _monotonic_probability_fn(score, previous_alignments, sigmoid_noise, mode,
    690                               seed=None):
    691   """Attention probability function for monotonic attention.
    692 
    693   Takes in unnormalized attention scores, adds pre-sigmoid noise to encourage
    694   the model to make discrete attention decisions, passes them through a sigmoid
    695   to obtain "choosing" probabilities, and then calls monotonic_attention to
    696   obtain the attention distribution.  For more information, see
    697 
    698   Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    699   "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    700   ICML 2017.  https://arxiv.org/abs/1704.00784
    701 
    702   Args:
    703     score: Unnormalized attention scores, shape `[batch_size, alignments_size]`
    704     previous_alignments: Previous attention distribution, shape
    705       `[batch_size, alignments_size]`
    706     sigmoid_noise: Standard deviation of pre-sigmoid noise.  Setting this larger
    707       than 0 will encourage the model to produce large attention scores,
    708       effectively making the choosing probabilities discrete and the resulting
    709       attention distribution one-hot.  It should be set to 0 at test-time, and
    710       when hard attention is not desired.
    711     mode: How to compute the attention distribution.  Must be one of
    712       'recursive', 'parallel', or 'hard'.  See the docstring for
    713       `tf.contrib.seq2seq.monotonic_attention` for more information.
    714     seed: (optional) Random seed for pre-sigmoid noise.
    715 
    716   Returns:
    717     A `[batch_size, alignments_size]`-shape tensor corresponding to the
    718     resulting attention distribution.
    719   """
    720   # Optionally add pre-sigmoid noise to the scores
    721   if sigmoid_noise > 0:
    722     noise = random_ops.random_normal(array_ops.shape(score), dtype=score.dtype,
    723                                      seed=seed)
    724     score += sigmoid_noise*noise
    725   # Compute "choosing" probabilities from the attention scores
    726   if mode == "hard":
    727     # When mode is hard, use a hard sigmoid
    728     p_choose_i = math_ops.cast(score > 0, score.dtype)
    729   else:
    730     p_choose_i = math_ops.sigmoid(score)
    731   # Convert from choosing probabilities to attention distribution
    732   return monotonic_attention(p_choose_i, previous_alignments, mode)
    733 
    734 
    735 class _BaseMonotonicAttentionMechanism(_BaseAttentionMechanism):
    736   """Base attention mechanism for monotonic attention.
    737 
     738   Simply overrides the initial_alignments function to provide a Dirac
     739   distribution, which is needed in order for the monotonic attention
    740   distributions to have the correct behavior.
    741   """
    742 
    743   def initial_alignments(self, batch_size, dtype):
    744     """Creates the initial alignment values for the monotonic attentions.
    745 
     746     Initializes to Dirac distributions, i.e. [1, 0, 0, ...memory length..., 0]
    747     for all entries in the batch.
    748 
    749     Args:
    750       batch_size: `int32` scalar, the batch_size.
    751       dtype: The `dtype`.
    752 
    753     Returns:
    754       A `dtype` tensor shaped `[batch_size, alignments_size]`
    755       (`alignments_size` is the values' `max_time`).
    756     """
    757     max_time = self._alignments_size
    758     return array_ops.one_hot(
    759         array_ops.zeros((batch_size,), dtype=dtypes.int32), max_time,
    760         dtype=dtype)
    761 
    762 
    763 class BahdanauMonotonicAttention(_BaseMonotonicAttentionMechanism):
     764   """Monotonic attention mechanism with Bahdanau-style energy function.
    765 
     766   This type of attention enforces a monotonic constraint on the attention
     767   distributions; that is, once the model attends to a given point in the memory
     768   it can't attend to any prior points at subsequent output timesteps.  It
    769   achieves this by using the _monotonic_probability_fn instead of softmax to
    770   construct its attention distributions.  Since the attention scores are passed
    771   through a sigmoid, a learnable scalar bias parameter is applied after the
    772   score function and before the sigmoid.  Otherwise, it is equivalent to
    773   BahdanauAttention.  This approach is proposed in
    774 
    775   Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    776   "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    777   ICML 2017.  https://arxiv.org/abs/1704.00784
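
           Example, a minimal sketch (`encoder_outputs` and `source_lengths` are
           placeholders; the particular values below are illustrative only):

           ```python
           attention_mechanism = tf.contrib.seq2seq.BahdanauMonotonicAttention(
               num_units=128,
               memory=encoder_outputs,
               memory_sequence_length=source_lengths,
               sigmoid_noise=1.0,      # pre-sigmoid noise; use 0. at test time
               score_bias_init=-4.0)   # negative init helps with long memories
           ```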
    778   """
    779 
    780   def __init__(self,
    781                num_units,
    782                memory,
    783                memory_sequence_length=None,
    784                normalize=False,
    785                score_mask_value=None,
    786                sigmoid_noise=0.,
    787                sigmoid_noise_seed=None,
    788                score_bias_init=0.,
    789                mode="parallel",
    790                dtype=None,
    791                name="BahdanauMonotonicAttention"):
    792     """Construct the Attention mechanism.
    793 
    794     Args:
    795       num_units: The depth of the query mechanism.
    796       memory: The memory to query; usually the output of an RNN encoder.  This
    797         tensor should be shaped `[batch_size, max_time, ...]`.
     798       memory_sequence_length: (optional) Sequence lengths for the batch entries
    799         in memory.  If provided, the memory tensor rows are masked with zeros
    800         for values past the respective sequence lengths.
    801       normalize: Python boolean.  Whether to normalize the energy term.
     802       score_mask_value: (optional) The mask value for score before passing into
    803         `probability_fn`. The default is -inf. Only used if
    804         `memory_sequence_length` is not None.
    805       sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
    806         for `_monotonic_probability_fn` for more information.
    807       sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
    808       score_bias_init: Initial value for score bias scalar.  It's recommended to
    809         initialize this to a negative value when the length of the memory is
    810         large.
    811       mode: How to compute the attention distribution.  Must be one of
    812         'recursive', 'parallel', or 'hard'.  See the docstring for
    813         `tf.contrib.seq2seq.monotonic_attention` for more information.
    814       dtype: The data type for the query and memory layers of the attention
    815         mechanism.
    816       name: Name to use when creating ops.
    817     """
    818     # Set up the monotonic probability fn with supplied parameters
    819     if dtype is None:
    820       dtype = dtypes.float32
    821     wrapped_probability_fn = functools.partial(
    822         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
    823         seed=sigmoid_noise_seed)
    824     super(BahdanauMonotonicAttention, self).__init__(
    825         query_layer=layers_core.Dense(
    826             num_units, name="query_layer", use_bias=False, dtype=dtype),
    827         memory_layer=layers_core.Dense(
    828             num_units, name="memory_layer", use_bias=False, dtype=dtype),
    829         memory=memory,
    830         probability_fn=wrapped_probability_fn,
    831         memory_sequence_length=memory_sequence_length,
    832         score_mask_value=score_mask_value,
    833         name=name)
    834     self._num_units = num_units
    835     self._normalize = normalize
    836     self._name = name
    837     self._score_bias_init = score_bias_init
    838 
    839   def __call__(self, query, state):
    840     """Score the query based on the keys and values.
    841 
    842     Args:
    843       query: Tensor of dtype matching `self.values` and shape
    844         `[batch_size, query_depth]`.
    845       state: Tensor of dtype matching `self.values` and shape
    846         `[batch_size, alignments_size]`
    847         (`alignments_size` is memory's `max_time`).
    848 
    849     Returns:
    850       alignments: Tensor of dtype matching `self.values` and shape
    851         `[batch_size, alignments_size]` (`alignments_size` is memory's
    852         `max_time`).
    853     """
    854     with variable_scope.variable_scope(
    855         None, "bahdanau_monotonic_attention", [query]):
    856       processed_query = self.query_layer(query) if self.query_layer else query
    857       score = _bahdanau_score(processed_query, self._keys, self._normalize)
    858       score_bias = variable_scope.get_variable(
    859           "attention_score_bias", dtype=processed_query.dtype,
    860           initializer=self._score_bias_init)
    861       score += score_bias
    862     alignments = self._probability_fn(score, state)
    863     next_state = alignments
    864     return alignments, next_state
    865 
    866 
    867 class LuongMonotonicAttention(_BaseMonotonicAttentionMechanism):
    868   """Monotonic attention mechanism with Luong-style energy function.
    869 
     870   This type of attention enforces a monotonic constraint on the attention
     871   distributions; that is, once the model attends to a given point in the memory
     872   it can't attend to any prior points at subsequent output timesteps.  It
    873   achieves this by using the _monotonic_probability_fn instead of softmax to
    874   construct its attention distributions.  Otherwise, it is equivalent to
    875   LuongAttention.  This approach is proposed in
    876 
    877   Colin Raffel, Minh-Thang Luong, Peter J. Liu, Ron J. Weiss, Douglas Eck,
    878   "Online and Linear-Time Attention by Enforcing Monotonic Alignments."
    879   ICML 2017.  https://arxiv.org/abs/1704.00784
    880   """
    881 
    882   def __init__(self,
    883                num_units,
    884                memory,
    885                memory_sequence_length=None,
    886                scale=False,
    887                score_mask_value=None,
    888                sigmoid_noise=0.,
    889                sigmoid_noise_seed=None,
    890                score_bias_init=0.,
    891                mode="parallel",
    892                dtype=None,
    893                name="LuongMonotonicAttention"):
    894     """Construct the Attention mechanism.
    895 
    896     Args:
    897       num_units: The depth of the query mechanism.
    898       memory: The memory to query; usually the output of an RNN encoder.  This
    899         tensor should be shaped `[batch_size, max_time, ...]`.
     900       memory_sequence_length: (optional) Sequence lengths for the batch entries
    901         in memory.  If provided, the memory tensor rows are masked with zeros
    902         for values past the respective sequence lengths.
    903       scale: Python boolean.  Whether to scale the energy term.
     904       score_mask_value: (optional) The mask value for score before passing into
    905         `probability_fn`. The default is -inf. Only used if
    906         `memory_sequence_length` is not None.
    907       sigmoid_noise: Standard deviation of pre-sigmoid noise.  See the docstring
    908         for `_monotonic_probability_fn` for more information.
    909       sigmoid_noise_seed: (optional) Random seed for pre-sigmoid noise.
    910       score_bias_init: Initial value for score bias scalar.  It's recommended to
    911         initialize this to a negative value when the length of the memory is
    912         large.
    913       mode: How to compute the attention distribution.  Must be one of
    914         'recursive', 'parallel', or 'hard'.  See the docstring for
    915         `tf.contrib.seq2seq.monotonic_attention` for more information.
    916       dtype: The data type for the query and memory layers of the attention
    917         mechanism.
    918       name: Name to use when creating ops.
    919     """
    920     # Set up the monotonic probability fn with supplied parameters
    921     if dtype is None:
    922       dtype = dtypes.float32
    923     wrapped_probability_fn = functools.partial(
    924         _monotonic_probability_fn, sigmoid_noise=sigmoid_noise, mode=mode,
    925         seed=sigmoid_noise_seed)
    926     super(LuongMonotonicAttention, self).__init__(
    927         query_layer=None,
    928         memory_layer=layers_core.Dense(
    929             num_units, name="memory_layer", use_bias=False, dtype=dtype),
    930         memory=memory,
    931         probability_fn=wrapped_probability_fn,
    932         memory_sequence_length=memory_sequence_length,
    933         score_mask_value=score_mask_value,
    934         name=name)
    935     self._num_units = num_units
    936     self._scale = scale
    937     self._score_bias_init = score_bias_init
    938     self._name = name
    939 
    940   def __call__(self, query, state):
    941     """Score the query based on the keys and values.
    942 
    943     Args:
    944       query: Tensor of dtype matching `self.values` and shape
    945         `[batch_size, query_depth]`.
    946       state: Tensor of dtype matching `self.values` and shape
    947         `[batch_size, alignments_size]`
    948         (`alignments_size` is memory's `max_time`).
    949 
    950     Returns:
    951       alignments: Tensor of dtype matching `self.values` and shape
    952         `[batch_size, alignments_size]` (`alignments_size` is memory's
    953         `max_time`).
    954     """
    955     with variable_scope.variable_scope(None, "luong_monotonic_attention",
    956                                        [query]):
    957       score = _luong_score(query, self._keys, self._scale)
    958       score_bias = variable_scope.get_variable(
    959           "attention_score_bias", dtype=query.dtype,
    960           initializer=self._score_bias_init)
    961       score += score_bias
    962     alignments = self._probability_fn(score, state)
    963     next_state = alignments
    964     return alignments, next_state
    965 
    966 
    967 class AttentionWrapperState(
    968     collections.namedtuple("AttentionWrapperState",
    969                            ("cell_state", "attention", "time", "alignments",
    970                             "alignment_history", "attention_state"))):
     971   """`namedtuple` storing the state of an `AttentionWrapper`.
    972 
    973   Contains:
    974 
    975     - `cell_state`: The state of the wrapped `RNNCell` at the previous time
    976       step.
    977     - `attention`: The attention emitted at the previous time step.
    978     - `time`: int32 scalar containing the current time step.
    979     - `alignments`: A single or tuple of `Tensor`(s) containing the alignments
    980        emitted at the previous time step for each attention mechanism.
    981     - `alignment_history`: (if enabled) a single or tuple of `TensorArray`(s)
    982        containing alignment matrices from all time steps for each attention
    983        mechanism. Call `stack()` on each to convert to a `Tensor`.
    984     - `attention_state`: A single or tuple of nested objects
    985        containing attention mechanism state for each attention mechanism.
    986        The objects may contain Tensors or TensorArrays.
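
           For example, with a single attention mechanism and `alignment_history=True`
           (a sketch; `final_state` is whatever the decoder or `tf.nn.dynamic_rnn`
           returned as the final state of the wrapped cell):

           ```python
           alignments_over_time = final_state.alignment_history.stack()
           # shape: [decoder_max_time, batch_size, memory_max_time]
           ```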
    987   """
    988 
    989   def clone(self, **kwargs):
    990     """Clone this object, overriding components provided by kwargs.
    991 
    992     The new state fields' shape must match original state fields' shape. This
    993     will be validated, and original fields' shape will be propagated to new
    994     fields.
    995 
    996     Example:
    997 
    998     ```python
    999     initial_state = attention_wrapper.zero_state(dtype=..., batch_size=...)
   1000     initial_state = initial_state.clone(cell_state=encoder_state)
   1001     ```
   1002 
   1003     Args:
   1004       **kwargs: Any properties of the state object to replace in the returned
   1005         `AttentionWrapperState`.
   1006 
   1007     Returns:
   1008       A new `AttentionWrapperState` whose properties are the same as
   1009       this one, except any overridden properties as provided in `kwargs`.
   1010     """
   1011     def with_same_shape(old, new):
   1012       """Check and set new tensor's shape."""
   1013       if isinstance(old, ops.Tensor) and isinstance(new, ops.Tensor):
   1014         return tensor_util.with_same_shape(old, new)
   1015       return new
   1016 
   1017     return nest.map_structure(
   1018         with_same_shape,
   1019         self,
   1020         super(AttentionWrapperState, self)._replace(**kwargs))
   1021 
   1022 
   1023 def hardmax(logits, name=None):
   1024   """Returns batched one-hot vectors.
   1025 
   1026   The depth index containing the `1` is that of the maximum logit value.
   1027 
   1028   Args:
   1029     logits: A batch tensor of logit values.
   1030     name: Name to use when creating ops.
   1031   Returns:
   1032     A batched one-hot tensor.
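
           For example (an illustrative call; the argmax index gets the `1`):

           ```python
           hardmax(tf.constant([[2.0, 5.0, 1.0]]))  # ==> [[0., 1., 0.]]
           ```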
   1033   """
   1034   with ops.name_scope(name, "Hardmax", [logits]):
   1035     logits = ops.convert_to_tensor(logits, name="logits")
   1036     if logits.get_shape()[-1].value is not None:
   1037       depth = logits.get_shape()[-1].value
   1038     else:
   1039       depth = array_ops.shape(logits)[-1]
   1040     return array_ops.one_hot(
   1041         math_ops.argmax(logits, -1), depth, dtype=logits.dtype)
   1042 
   1043 
   1044 def _compute_attention(attention_mechanism, cell_output, attention_state,
   1045                        attention_layer):
   1046   """Computes the attention and alignments for a given attention_mechanism."""
   1047   alignments, next_attention_state = attention_mechanism(
   1048       cell_output, state=attention_state)
   1049 
   1050   # Reshape from [batch_size, memory_time] to [batch_size, 1, memory_time]
   1051   expanded_alignments = array_ops.expand_dims(alignments, 1)
   1052   # Context is the inner product of alignments and values along the
   1053   # memory time dimension.
   1054   # alignments shape is
   1055   #   [batch_size, 1, memory_time]
   1056   # attention_mechanism.values shape is
   1057   #   [batch_size, memory_time, memory_size]
   1058   # the batched matmul is over memory_time, so the output shape is
   1059   #   [batch_size, 1, memory_size].
   1060   # we then squeeze out the singleton dim.
   1061   context = math_ops.matmul(expanded_alignments, attention_mechanism.values)
   1062   context = array_ops.squeeze(context, [1])
   1063 
   1064   if attention_layer is not None:
   1065     attention = attention_layer(array_ops.concat([cell_output, context], 1))
   1066   else:
   1067     attention = context
   1068 
   1069   return attention, alignments, next_attention_state
   1070 
   1071 
   1072 class AttentionWrapper(rnn_cell_impl.RNNCell):
   1073   """Wraps another `RNNCell` with attention.
   1074   """
   1075 
   1076   def __init__(self,
   1077                cell,
   1078                attention_mechanism,
   1079                attention_layer_size=None,
   1080                alignment_history=False,
   1081                cell_input_fn=None,
   1082                output_attention=True,
   1083                initial_cell_state=None,
   1084                name=None):
   1085     """Construct the `AttentionWrapper`.
   1086 
   1087     **NOTE** If you are using the `BeamSearchDecoder` with a cell wrapped in
   1088     `AttentionWrapper`, then you must ensure that:
   1089 
   1090     - The encoder output has been tiled to `beam_width` via
   1091       @{tf.contrib.seq2seq.tile_batch} (NOT `tf.tile`).
   1092     - The `batch_size` argument passed to the `zero_state` method of this
   1093       wrapper is equal to `true_batch_size * beam_width`.
   1094     - The initial state created with `zero_state` above contains a
   1095       `cell_state` value containing properly tiled final state from the
   1096       encoder.
   1097 
   1098     An example:
   1099 
   1100     ```
   1101     tiled_encoder_outputs = tf.contrib.seq2seq.tile_batch(
   1102         encoder_outputs, multiplier=beam_width)
    1103     tiled_encoder_final_state = tf.contrib.seq2seq.tile_batch(
   1104         encoder_final_state, multiplier=beam_width)
   1105     tiled_sequence_length = tf.contrib.seq2seq.tile_batch(
   1106         sequence_length, multiplier=beam_width)
   1107     attention_mechanism = MyFavoriteAttentionMechanism(
   1108         num_units=attention_depth,
    1109         memory=tiled_encoder_outputs,
   1110         memory_sequence_length=tiled_sequence_length)
   1111     attention_cell = AttentionWrapper(cell, attention_mechanism, ...)
   1112     decoder_initial_state = attention_cell.zero_state(
   1113         dtype, batch_size=true_batch_size * beam_width)
   1114     decoder_initial_state = decoder_initial_state.clone(
   1115         cell_state=tiled_encoder_final_state)
   1116     ```
   1117 
   1118     Args:
   1119       cell: An instance of `RNNCell`.
   1120       attention_mechanism: A list of `AttentionMechanism` instances or a single
   1121         instance.
   1122       attention_layer_size: A list of Python integers or a single Python
   1123         integer, the depth of the attention (output) layer(s). If None
   1124         (default), use the context as attention at each time step. Otherwise,
   1125         feed the context and cell output into the attention layer to generate
   1126         attention at each time step. If attention_mechanism is a list,
   1127         attention_layer_size must be a list of the same length.
   1128       alignment_history: Python boolean, whether to store alignment history
   1129         from all time steps in the final output state (currently stored as a
   1130         time major `TensorArray` on which you must call `stack()`).
   1131       cell_input_fn: (optional) A `callable`.  The default is:
   1132         `lambda inputs, attention: array_ops.concat([inputs, attention], -1)`.
   1133       output_attention: Python bool.  If `True` (default), the output at each
   1134         time step is the attention value.  This is the behavior of Luong-style
   1135         attention mechanisms.  If `False`, the output at each time step is
    1136         the output of `cell`.  This is the behavior of Bahdanau-style
   1137         attention mechanisms.  In both cases, the `attention` tensor is
   1138         propagated to the next time step via the state and is used there.
   1139         This flag only controls whether the attention mechanism is propagated
   1140         up to the next cell in an RNN stack or to the top RNN output.
   1141       initial_cell_state: The initial state value to use for the cell when
   1142         the user calls `zero_state()`.  Note that if this value is provided
   1143         now, and the user uses a `batch_size` argument of `zero_state` which
   1144         does not match the batch size of `initial_cell_state`, proper
   1145         behavior is not guaranteed.
   1146       name: Name to use when creating ops.
   1147 
   1148     Raises:
   1149       TypeError: `attention_layer_size` is not None and (`attention_mechanism`
   1150         is a list but `attention_layer_size` is not; or vice versa).
   1151       ValueError: if `attention_layer_size` is not None, `attention_mechanism`
   1152         is a list, and its length does not match that of `attention_layer_size`.
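
             For example, a custom `cell_input_fn` that projects the concatenated
             inputs and attention back to the cell's expected input depth might
             look like the following minimal sketch (`cell_input_depth` and the
             layer name are illustrative, not part of this API):

             ```
             input_projection = tf.layers.Dense(
                 cell_input_depth, use_bias=False, name="input_projection")
             attention_cell = AttentionWrapper(
                 cell, attention_mechanism,
                 cell_input_fn=lambda inputs, attention: input_projection(
                     tf.concat([inputs, attention], -1)))
             ```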
   1153     """
   1154     super(AttentionWrapper, self).__init__(name=name)
   1155     if not rnn_cell_impl._like_rnncell(cell):  # pylint: disable=protected-access
   1156       raise TypeError(
   1157           "cell must be an RNNCell, saw type: %s" % type(cell).__name__)
   1158     if isinstance(attention_mechanism, (list, tuple)):
   1159       self._is_multi = True
   1160       attention_mechanisms = attention_mechanism
   1161       for attention_mechanism in attention_mechanisms:
   1162         if not isinstance(attention_mechanism, AttentionMechanism):
   1163           raise TypeError(
   1164               "attention_mechanism must contain only instances of "
   1165               "AttentionMechanism, saw type: %s"
   1166               % type(attention_mechanism).__name__)
   1167     else:
   1168       self._is_multi = False
   1169       if not isinstance(attention_mechanism, AttentionMechanism):
   1170         raise TypeError(
   1171             "attention_mechanism must be an AttentionMechanism or list of "
   1172             "multiple AttentionMechanism instances, saw type: %s"
   1173             % type(attention_mechanism).__name__)
   1174       attention_mechanisms = (attention_mechanism,)
   1175 
   1176     if cell_input_fn is None:
   1177       cell_input_fn = (
   1178           lambda inputs, attention: array_ops.concat([inputs, attention], -1))
   1179     else:
   1180       if not callable(cell_input_fn):
   1181         raise TypeError(
   1182             "cell_input_fn must be callable, saw type: %s"
   1183             % type(cell_input_fn).__name__)
   1184 
   1185     if attention_layer_size is not None:
   1186       attention_layer_sizes = tuple(
   1187           attention_layer_size
   1188           if isinstance(attention_layer_size, (list, tuple))
   1189           else (attention_layer_size,))
   1190       if len(attention_layer_sizes) != len(attention_mechanisms):
   1191         raise ValueError(
   1192             "If provided, attention_layer_size must contain exactly one "
   1193             "integer per attention_mechanism, saw: %d vs %d"
   1194             % (len(attention_layer_sizes), len(attention_mechanisms)))
   1195       self._attention_layers = tuple(
   1196           layers_core.Dense(
   1197               attention_layer_size,
   1198               name="attention_layer",
   1199               use_bias=False,
   1200               dtype=attention_mechanisms[i].dtype)
   1201           for i, attention_layer_size in enumerate(attention_layer_sizes))
   1202       self._attention_layer_size = sum(attention_layer_sizes)
   1203     else:
   1204       self._attention_layers = None
   1205       self._attention_layer_size = sum(
   1206           attention_mechanism.values.get_shape()[-1].value
   1207           for attention_mechanism in attention_mechanisms)
   1208 
   1209     self._cell = cell
   1210     self._attention_mechanisms = attention_mechanisms
   1211     self._cell_input_fn = cell_input_fn
   1212     self._output_attention = output_attention
   1213     self._alignment_history = alignment_history
   1214     with ops.name_scope(name, "AttentionWrapperInit"):
   1215       if initial_cell_state is None:
   1216         self._initial_cell_state = None
   1217       else:
   1218         final_state_tensor = nest.flatten(initial_cell_state)[-1]
   1219         state_batch_size = (
   1220             final_state_tensor.shape[0].value
   1221             or array_ops.shape(final_state_tensor)[0])
   1222         error_message = (
   1223             "When constructing AttentionWrapper %s: " % self._base_name +
   1224             "Non-matching batch sizes between the memory "
   1225             "(encoder output) and initial_cell_state.  Are you using "
   1226             "the BeamSearchDecoder?  You may need to tile your initial state "
   1227             "via the tf.contrib.seq2seq.tile_batch function with argument "
   1228             "multiple=beam_width.")
   1229         with ops.control_dependencies(
   1230             self._batch_size_checks(state_batch_size, error_message)):
   1231           self._initial_cell_state = nest.map_structure(
   1232               lambda s: array_ops.identity(s, name="check_initial_cell_state"),
   1233               initial_cell_state)
   1234 
   1235   def _batch_size_checks(self, batch_size, error_message):
   1236     return [check_ops.assert_equal(batch_size,
   1237                                    attention_mechanism.batch_size,
   1238                                    message=error_message)
   1239             for attention_mechanism in self._attention_mechanisms]
   1240 
   1241   def _item_or_tuple(self, seq):
   1242     """Returns `seq` as tuple or the singular element.
   1243 
    1244     Which form is returned depends on how the AttentionMechanism(s) were passed
    1245     to the constructor.
   1246 
   1247     Args:
   1248       seq: A non-empty sequence of items or generator.
   1249 
   1250     Returns:
    1251        The values in the sequence as a tuple if the AttentionMechanism(s) were
    1252        passed to the constructor as a sequence, otherwise the single element.
   1253     """
   1254     t = tuple(seq)
   1255     if self._is_multi:
   1256       return t
   1257     else:
   1258       return t[0]
   1259 
   1260   @property
   1261   def output_size(self):
   1262     if self._output_attention:
   1263       return self._attention_layer_size
   1264     else:
   1265       return self._cell.output_size
   1266 
   1267   @property
   1268   def state_size(self):
   1269     """The `state_size` property of `AttentionWrapper`.
   1270 
   1271     Returns:
   1272       An `AttentionWrapperState` tuple containing shapes used by this object.
   1273     """
   1274     return AttentionWrapperState(
   1275         cell_state=self._cell.state_size,
   1276         time=tensor_shape.TensorShape([]),
   1277         attention=self._attention_layer_size,
   1278         alignments=self._item_or_tuple(
   1279             a.alignments_size for a in self._attention_mechanisms),
   1280         attention_state=self._item_or_tuple(
   1281             a.state_size for a in self._attention_mechanisms),
   1282         alignment_history=self._item_or_tuple(
   1283             () for _ in self._attention_mechanisms))  # sometimes a TensorArray
   1284 
   1285   def zero_state(self, batch_size, dtype):
   1286     """Return an initial (zero) state tuple for this `AttentionWrapper`.
   1287 
   1288     **NOTE** Please see the initializer documentation for details of how
   1289     to call `zero_state` if using an `AttentionWrapper` with a
   1290     `BeamSearchDecoder`.
   1291 
   1292     Args:
   1293       batch_size: `0D` integer tensor: the batch size.
   1294       dtype: The internal state data type.
   1295 
   1296     Returns:
   1297       An `AttentionWrapperState` tuple containing zeroed out tensors and,
   1298       possibly, empty `TensorArray` objects.
   1299 
   1300     Raises:
    1301       ValueError: (or, possibly at runtime, `InvalidArgumentError`), if
   1302         `batch_size` does not match the output size of the encoder passed
   1303         to the wrapper object at initialization time.
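
             For example, without beam search (a minimal sketch; `attention_cell`,
             `batch_size`, and `encoder_final_state` are illustrative names):

             ```
             decoder_initial_state = attention_cell.zero_state(
                 batch_size, tf.float32).clone(cell_state=encoder_final_state)
             ```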
   1304     """
   1305     with ops.name_scope(type(self).__name__ + "ZeroState", values=[batch_size]):
   1306       if self._initial_cell_state is not None:
   1307         cell_state = self._initial_cell_state
   1308       else:
   1309         cell_state = self._cell.zero_state(batch_size, dtype)
   1310       error_message = (
   1311           "When calling zero_state of AttentionWrapper %s: " % self._base_name +
   1312           "Non-matching batch sizes between the memory "
   1313           "(encoder output) and the requested batch size.  Are you using "
   1314           "the BeamSearchDecoder?  If so, make sure your encoder output has "
   1315           "been tiled to beam_width via tf.contrib.seq2seq.tile_batch, and "
   1316           "the batch_size= argument passed to zero_state is "
   1317           "batch_size * beam_width.")
   1318       with ops.control_dependencies(
   1319           self._batch_size_checks(batch_size, error_message)):
   1320         cell_state = nest.map_structure(
   1321             lambda s: array_ops.identity(s, name="checked_cell_state"),
   1322             cell_state)
   1323       return AttentionWrapperState(
   1324           cell_state=cell_state,
   1325           time=array_ops.zeros([], dtype=dtypes.int32),
   1326           attention=_zero_state_tensors(self._attention_layer_size, batch_size,
   1327                                         dtype),
   1328           alignments=self._item_or_tuple(
   1329               attention_mechanism.initial_alignments(batch_size, dtype)
   1330               for attention_mechanism in self._attention_mechanisms),
   1331           attention_state=self._item_or_tuple(
   1332               attention_mechanism.initial_state(batch_size, dtype)
   1333               for attention_mechanism in self._attention_mechanisms),
   1334           alignment_history=self._item_or_tuple(
   1335               tensor_array_ops.TensorArray(dtype=dtype, size=0,
   1336                                            dynamic_size=True)
   1337               if self._alignment_history else ()
   1338               for _ in self._attention_mechanisms))
   1339 
   1340   def call(self, inputs, state):
   1341     """Perform a step of attention-wrapped RNN.
   1342 
   1343     - Step 1: Mix the `inputs` and previous step's `attention` output via
   1344       `cell_input_fn`.
   1345     - Step 2: Call the wrapped `cell` with this input and its previous state.
   1346     - Step 3: Score the cell's output with `attention_mechanism`.
   1347     - Step 4: Calculate the alignments by passing the score through the
   1348       `normalizer`.
   1349     - Step 5: Calculate the context vector as the inner product between the
   1350       alignments and the attention_mechanism's values (memory).
    1351     - Step 6: Calculate the attention output by concatenating the cell output
    1352       and context and passing the result through the attention layer (a linear
    1353       layer with `attention_layer_size` outputs), as sketched below.
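
             Conceptually, Steps 3-6 compute something like the following simplified
             sketch (illustrative pseudocode, not the exact implementation; when the
             attention layer is not used, the context itself is the attention):

             ```
             alignments, next_attention_state = attention_mechanism(
                 cell_output, state=previous_attention_state)
             context = tf.squeeze(
                 tf.matmul(tf.expand_dims(alignments, 1),
                           attention_mechanism.values), [1])
             attention = attention_layer(tf.concat([cell_output, context], 1))
             ```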
   1354 
   1355     Args:
   1356       inputs: (Possibly nested tuple of) Tensor, the input at this time step.
   1357       state: An instance of `AttentionWrapperState` containing
   1358         tensors from the previous time step.
   1359 
   1360     Returns:
   1361       A tuple `(attention_or_cell_output, next_state)`, where:
   1362 
    1363       - `attention_or_cell_output`: the attention value if `output_attention`
                  is `True`, otherwise the output of `cell`.
   1364       - `next_state` is an instance of `AttentionWrapperState`
   1365          containing the state calculated at this time step.
   1366 
   1367     Raises:
   1368       TypeError: If `state` is not an instance of `AttentionWrapperState`.
   1369     """
   1370     if not isinstance(state, AttentionWrapperState):
   1371       raise TypeError("Expected state to be instance of AttentionWrapperState. "
   1372                       "Received type %s instead."  % type(state))
   1373 
   1374     # Step 1: Calculate the true inputs to the cell based on the
   1375     # previous attention value.
   1376     cell_inputs = self._cell_input_fn(inputs, state.attention)
   1377     cell_state = state.cell_state
   1378     cell_output, next_cell_state = self._cell(cell_inputs, cell_state)
   1379 
   1380     cell_batch_size = (
   1381         cell_output.shape[0].value or array_ops.shape(cell_output)[0])
   1382     error_message = (
   1383         "When applying AttentionWrapper %s: " % self.name +
   1384         "Non-matching batch sizes between the memory "
   1385         "(encoder output) and the query (decoder output).  Are you using "
   1386         "the BeamSearchDecoder?  You may need to tile your memory input via "
   1387         "the tf.contrib.seq2seq.tile_batch function with argument "
   1388         "multiple=beam_width.")
   1389     with ops.control_dependencies(
   1390         self._batch_size_checks(cell_batch_size, error_message)):
   1391       cell_output = array_ops.identity(
   1392           cell_output, name="checked_cell_output")
   1393 
   1394     if self._is_multi:
   1395       previous_attention_state = state.attention_state
   1396       previous_alignment_history = state.alignment_history
   1397     else:
   1398       previous_attention_state = [state.attention_state]
   1399       previous_alignment_history = [state.alignment_history]
   1400 
   1401     all_alignments = []
   1402     all_attentions = []
   1403     all_attention_states = []
   1404     maybe_all_histories = []
   1405     for i, attention_mechanism in enumerate(self._attention_mechanisms):
   1406       attention, alignments, next_attention_state = _compute_attention(
   1407           attention_mechanism, cell_output, previous_attention_state[i],
   1408           self._attention_layers[i] if self._attention_layers else None)
   1409       alignment_history = previous_alignment_history[i].write(
   1410           state.time, alignments) if self._alignment_history else ()
   1411 
   1412       all_attention_states.append(next_attention_state)
   1413       all_alignments.append(alignments)
   1414       all_attentions.append(attention)
   1415       maybe_all_histories.append(alignment_history)
   1416 
   1417     attention = array_ops.concat(all_attentions, 1)
   1418     next_state = AttentionWrapperState(
   1419         time=state.time + 1,
   1420         cell_state=next_cell_state,
   1421         attention=attention,
   1422         attention_state=self._item_or_tuple(all_attention_states),
   1423         alignments=self._item_or_tuple(all_alignments),
   1424         alignment_history=self._item_or_tuple(maybe_all_histories))
   1425 
   1426     if self._output_attention:
   1427       return attention, next_state
   1428     else:
   1429       return cell_output, next_state
   1430