# Copyright 2016 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""A library of helpers for use with SamplingDecoders.
"""

from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

import abc

import six

from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.framework import tensor_shape
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import control_flow_ops
from tensorflow.python.ops import embedding_ops
from tensorflow.python.ops import gen_array_ops
from tensorflow.python.ops import math_ops
from tensorflow.python.ops import tensor_array_ops
from tensorflow.python.ops.distributions import bernoulli
from tensorflow.python.ops.distributions import categorical
from tensorflow.python.util import nest

__all__ = [
    "Helper",
    "TrainingHelper",
    "GreedyEmbeddingHelper",
    "SampleEmbeddingHelper",
    "CustomHelper",
    "ScheduledEmbeddingTrainingHelper",
    "ScheduledOutputTrainingHelper",
    "InferenceHelper",
]

_transpose_batch_time = decoder._transpose_batch_time  # pylint: disable=protected-access


def _unstack_ta(inp):
  return tensor_array_ops.TensorArray(
      dtype=inp.dtype, size=array_ops.shape(inp)[0],
      element_shape=inp.get_shape()[1:]).unstack(inp)


@six.add_metaclass(abc.ABCMeta)
class Helper(object):
  """Interface for implementing sampling in seq2seq decoders.

  Helper instances are used by `BasicDecoder`.
  """

  @abc.abstractproperty
  def batch_size(self):
    """Batch size of tensor returned by `sample`.

    Returns a scalar int32 tensor.
    """
    raise NotImplementedError("batch_size has not been implemented")

  @abc.abstractproperty
  def sample_ids_shape(self):
    """Shape of tensor returned by `sample`, excluding the batch dimension.

    Returns a `TensorShape`.
    """
    raise NotImplementedError("sample_ids_shape has not been implemented")

  @abc.abstractproperty
  def sample_ids_dtype(self):
    """DType of tensor returned by `sample`.

    Returns a DType.
    """
    raise NotImplementedError("sample_ids_dtype has not been implemented")

  @abc.abstractmethod
  def initialize(self, name=None):
    """Returns `(initial_finished, initial_inputs)`."""
    pass

  @abc.abstractmethod
  def sample(self, time, outputs, state, name=None):
    """Returns `sample_ids`."""
    pass

  @abc.abstractmethod
  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """Returns `(finished, next_inputs, next_state)`."""
    pass


class CustomHelper(Helper):
  """Base abstract class that allows the user to customize sampling."""

  def __init__(self, initialize_fn, sample_fn, next_inputs_fn,
               sample_ids_shape=None, sample_ids_dtype=None):
    """Initializer.

    Args:
      initialize_fn: callable that returns `(finished, next_inputs)`
        for the first iteration.
      sample_fn: callable that takes `(time, outputs, state)`
        and emits tensor `sample_ids`.
      next_inputs_fn: callable that takes `(time, outputs, state, sample_ids)`
        and emits `(finished, next_inputs, next_state)`.
      sample_ids_shape: Either a list of integers, or a 1-D Tensor of type
        `int32`, the shape of each value in the `sample_ids` batch. Defaults to
        a scalar.
      sample_ids_dtype: The dtype of the `sample_ids` tensor. Defaults to int32.
    """
    self._initialize_fn = initialize_fn
    self._sample_fn = sample_fn
    self._next_inputs_fn = next_inputs_fn
    self._batch_size = None
    self._sample_ids_shape = tensor_shape.TensorShape(sample_ids_shape or [])
    self._sample_ids_dtype = sample_ids_dtype or dtypes.int32

  @property
  def batch_size(self):
    if self._batch_size is None:
      raise ValueError("batch_size accessed before initialize was called")
    return self._batch_size

  @property
  def sample_ids_shape(self):
    return self._sample_ids_shape

  @property
  def sample_ids_dtype(self):
    return self._sample_ids_dtype

  def initialize(self, name=None):
    with ops.name_scope(name, "%sInitialize" % type(self).__name__):
      (finished, next_inputs) = self._initialize_fn()
      if self._batch_size is None:
        self._batch_size = array_ops.size(finished)
    return (finished, next_inputs)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(
        name, "%sSample" % type(self).__name__, (time, outputs, state)):
      return self._sample_fn(time=time, outputs=outputs, state=state)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(
        name, "%sNextInputs" % type(self).__name__, (time, outputs, state)):
      return self._next_inputs_fn(
          time=time, outputs=outputs, state=state, sample_ids=sample_ids)
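# Illustrative sketch (not part of the library): a `CustomHelper` is built from
# three callables matching the `initialize`/`sample`/`next_inputs` signatures
# above. The names `start_inputs`, `lengths`, and `embed` below are hypothetical
# placeholders supplied by the caller.
#
#   helper = CustomHelper(
#       initialize_fn=lambda: (math_ops.equal(0, lengths), start_inputs),
#       sample_fn=lambda time, outputs, state: math_ops.cast(
#           math_ops.argmax(outputs, axis=-1), dtypes.int32),
#       next_inputs_fn=lambda time, outputs, state, sample_ids: (
#           time + 1 >= lengths, embed(sample_ids), state))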


class TrainingHelper(Helper):
  """A helper for use during training.  Only reads inputs.

  Returned sample_ids are the argmax of the RNN output logits.
  """

  def __init__(self, inputs, sequence_length, time_major=False, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sequence_length` is not a 1D tensor.
    """
    with ops.name_scope(name, "TrainingHelper", [inputs, sequence_length]):
      inputs = ops.convert_to_tensor(inputs, name="inputs")
      self._inputs = inputs
      if not time_major:
        inputs = nest.map_structure(_transpose_batch_time, inputs)

      self._input_tas = nest.map_structure(_unstack_ta, inputs)
      self._sequence_length = ops.convert_to_tensor(
          sequence_length, name="sequence_length")
      if self._sequence_length.get_shape().ndims != 1:
        raise ValueError(
            "Expected sequence_length to be a vector, but received shape: %s" %
            self._sequence_length.get_shape())

      self._zero_inputs = nest.map_structure(
          lambda inp: array_ops.zeros_like(inp[0, :]), inputs)

      self._batch_size = array_ops.size(sequence_length)

  @property
  def inputs(self):
    return self._inputs

  @property
  def sequence_length(self):
    return self._sequence_length

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def sample_ids_shape(self):
    return tensor_shape.TensorShape([])

  @property
  def sample_ids_dtype(self):
    return dtypes.int32

  def initialize(self, name=None):
    with ops.name_scope(name, "TrainingHelperInitialize"):
      finished = math_ops.equal(0, self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(lambda inp: inp.read(0), self._input_tas))
      return (finished, next_inputs)

  def sample(self, time, outputs, name=None, **unused_kwargs):
    with ops.name_scope(name, "TrainingHelperSample", [time, outputs]):
      sample_ids = math_ops.cast(
          math_ops.argmax(outputs, axis=-1), dtypes.int32)
      return sample_ids

  def next_inputs(self, time, outputs, state, name=None, **unused_kwargs):
    """next_inputs_fn for TrainingHelper."""
    with ops.name_scope(name, "TrainingHelperNextInputs",
                        [time, outputs, state]):
      next_time = time + 1
      finished = (next_time >= self._sequence_length)
      all_finished = math_ops.reduce_all(finished)
      def read_from_ta(inp):
        return inp.read(next_time)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: self._zero_inputs,
          lambda: nest.map_structure(read_from_ta, self._input_tas))
      return (finished, next_inputs, state)
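# Illustrative sketch (assumption, not from this file): during training a
# `TrainingHelper` is typically paired with `tf.contrib.seq2seq.BasicDecoder`
# and `tf.contrib.seq2seq.dynamic_decode`. `decoder_inputs` (batch-major
# [batch, time, dims]), `lengths`, `cell`, and `initial_state` are hypothetical
# placeholders.
#
#   helper = TrainingHelper(inputs=decoder_inputs, sequence_length=lengths)
#   basic_decoder = tf.contrib.seq2seq.BasicDecoder(
#       cell=cell, helper=helper, initial_state=initial_state)
#   outputs, final_state, final_lengths = tf.contrib.seq2seq.dynamic_decode(
#       basic_decoder)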


class ScheduledEmbeddingTrainingHelper(TrainingHelper):
  """A training helper that adds scheduled sampling.

  Returns -1s for sample_ids where no sampling took place; valid sample id
  values elsewhere.
  """

  def __init__(self, inputs, sequence_length, embedding, sampling_probability,
               time_major=False, seed=None, scheduling_seed=None, name=None):
    """Initializer.

    Args:
      inputs: A (structure of) input tensors.
      sequence_length: An int32 vector tensor.
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`.
      sampling_probability: A 0D `float32` tensor: the probability of sampling
        categorically from the output ids instead of reading directly from the
        inputs.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      seed: The sampling seed.
      scheduling_seed: The schedule decision rule sampling seed.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sampling_probability` is not a scalar or vector.
    """
    with ops.name_scope(name, "ScheduledEmbeddingSamplingWrapper",
                        [embedding, sampling_probability]):
      if callable(embedding):
        self._embedding_fn = embedding
      else:
        self._embedding_fn = (
            lambda ids: embedding_ops.embedding_lookup(embedding, ids))
      self._sampling_probability = ops.convert_to_tensor(
          sampling_probability, name="sampling_probability")
      if self._sampling_probability.get_shape().ndims not in (0, 1):
        raise ValueError(
            "sampling_probability must be either a scalar or a vector. "
            "saw shape: %s" % (self._sampling_probability.get_shape()))
      self._seed = seed
      self._scheduling_seed = scheduling_seed
      super(ScheduledEmbeddingTrainingHelper, self).__init__(
          inputs=inputs,
          sequence_length=sequence_length,
          time_major=time_major,
          name=name)

  def initialize(self, name=None):
    return super(ScheduledEmbeddingTrainingHelper, self).initialize(name=name)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperSample",
                        [time, outputs, state]):
      # Return -1s where we did not sample, and sample_ids elsewhere
      select_sampler = bernoulli.Bernoulli(
          probs=self._sampling_probability, dtype=dtypes.bool)
      select_sample = select_sampler.sample(
          sample_shape=self.batch_size, seed=self._scheduling_seed)
      sample_id_sampler = categorical.Categorical(logits=outputs)
      return array_ops.where(
          select_sample,
          sample_id_sampler.sample(seed=self._seed),
          gen_array_ops.fill([self.batch_size], -1))

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(name, "ScheduledEmbeddingTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledEmbeddingTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))

      def maybe_sample():
        """Perform scheduled sampling."""
        where_sampling = math_ops.cast(
            array_ops.where(sample_ids > -1), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(sample_ids <= -1), dtypes.int32)
        sample_ids_sampling = array_ops.gather_nd(sample_ids, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(
            base_next_inputs, where_not_sampling)
        sampled_next_inputs = self._embedding_fn(sample_ids_sampling)
        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))

      all_finished = math_ops.reduce_all(finished)
      next_inputs = control_flow_ops.cond(
          all_finished, lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)
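# Illustrative sketch (assumption, not from this file): scheduled sampling over
# an embedding matrix. At each step the decoder reads the ground-truth input
# with probability 1 - sampling_probability and otherwise embeds an id sampled
# from its own output logits. `decoder_inputs`, `lengths`, and
# `embedding_matrix` are hypothetical placeholders.
#
#   helper = ScheduledEmbeddingTrainingHelper(
#       inputs=decoder_inputs,
#       sequence_length=lengths,
#       embedding=embedding_matrix,
#       sampling_probability=0.25)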


class ScheduledOutputTrainingHelper(TrainingHelper):
  """A training helper that adds scheduled sampling directly to outputs.

  Returns False for sample_ids where no sampling took place; True elsewhere.
  """

  def __init__(self, inputs, sequence_length, sampling_probability,
               time_major=False, seed=None, next_inputs_fn=None,
               auxiliary_inputs=None, name=None):
    """Initializer.

    Args:
      inputs: A (structure) of input tensors.
      sequence_length: An int32 vector tensor.
      sampling_probability: A 0D `float32` tensor: the probability of sampling
        from the outputs instead of reading directly from the inputs.
      time_major: Python bool.  Whether the tensors in `inputs` are time major.
        If `False` (default), they are assumed to be batch major.
      seed: The sampling seed.
      next_inputs_fn: (Optional) callable to apply to the RNN outputs to create
        the next input when sampling. If `None` (default), the RNN outputs will
        be used as the next inputs.
      auxiliary_inputs: An optional (structure of) auxiliary input tensors with
        a shape that matches `inputs` in all but (potentially) the final
        dimension. These tensors will be concatenated to the sampled output or
        the `inputs` when not sampling for use as the next input.
      name: Name scope for any created operations.

    Raises:
      ValueError: if `sampling_probability` is not a scalar or vector.
    """
    with ops.name_scope(name, "ScheduledOutputTrainingHelper",
                        [inputs, auxiliary_inputs, sampling_probability]):
      self._sampling_probability = ops.convert_to_tensor(
          sampling_probability, name="sampling_probability")
      if self._sampling_probability.get_shape().ndims not in (0, 1):
        raise ValueError(
            "sampling_probability must be either a scalar or a vector. "
            "saw shape: %s" % (self._sampling_probability.get_shape()))

      if auxiliary_inputs is None:
        maybe_concatenated_inputs = inputs
      else:
        inputs = ops.convert_to_tensor(inputs, name="inputs")
        auxiliary_inputs = ops.convert_to_tensor(
            auxiliary_inputs, name="auxiliary_inputs")
        maybe_concatenated_inputs = nest.map_structure(
            lambda x, y: array_ops.concat((x, y), -1),
            inputs, auxiliary_inputs)
        if not time_major:
          auxiliary_inputs = nest.map_structure(
              _transpose_batch_time, auxiliary_inputs)

      self._auxiliary_input_tas = (
          nest.map_structure(_unstack_ta, auxiliary_inputs)
          if auxiliary_inputs is not None else None)

      self._seed = seed

      self._next_inputs_fn = next_inputs_fn

      super(ScheduledOutputTrainingHelper, self).__init__(
          inputs=maybe_concatenated_inputs,
          sequence_length=sequence_length,
          time_major=time_major,
          name=name)

  def initialize(self, name=None):
    return super(ScheduledOutputTrainingHelper, self).initialize(name=name)

  def sample(self, time, outputs, state, name=None):
    with ops.name_scope(name, "ScheduledOutputTrainingHelperSample",
                        [time, outputs, state]):
      sampler = bernoulli.Bernoulli(probs=self._sampling_probability)
      return sampler.sample(sample_shape=self.batch_size, seed=self._seed)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    with ops.name_scope(name, "ScheduledOutputTrainingHelperNextInputs",
                        [time, outputs, state, sample_ids]):
      (finished, base_next_inputs, state) = (
          super(ScheduledOutputTrainingHelper, self).next_inputs(
              time=time,
              outputs=outputs,
              state=state,
              sample_ids=sample_ids,
              name=name))
      sample_ids = math_ops.cast(sample_ids, dtypes.bool)

      def maybe_sample():
        """Perform scheduled sampling."""

        def maybe_concatenate_auxiliary_inputs(outputs_, indices=None):
          """Concatenate outputs with auxiliary inputs, if they exist."""
          if self._auxiliary_input_tas is None:
            return outputs_

          next_time = time + 1
          auxiliary_inputs = nest.map_structure(
              lambda ta: ta.read(next_time), self._auxiliary_input_tas)
          if indices is not None:
            auxiliary_inputs = array_ops.gather_nd(auxiliary_inputs, indices)
          return nest.map_structure(
              lambda x, y: array_ops.concat((x, y), -1),
              outputs_, auxiliary_inputs)

        if self._next_inputs_fn is None:
          return array_ops.where(
              sample_ids, maybe_concatenate_auxiliary_inputs(outputs),
              base_next_inputs)

        where_sampling = math_ops.cast(
            array_ops.where(sample_ids), dtypes.int32)
        where_not_sampling = math_ops.cast(
            array_ops.where(math_ops.logical_not(sample_ids)), dtypes.int32)
        outputs_sampling = array_ops.gather_nd(outputs, where_sampling)
        inputs_not_sampling = array_ops.gather_nd(base_next_inputs,
                                                  where_not_sampling)
        sampled_next_inputs = maybe_concatenate_auxiliary_inputs(
            self._next_inputs_fn(outputs_sampling), where_sampling)

        base_shape = array_ops.shape(base_next_inputs)
        return (array_ops.scatter_nd(indices=where_sampling,
                                     updates=sampled_next_inputs,
                                     shape=base_shape)
                + array_ops.scatter_nd(indices=where_not_sampling,
                                       updates=inputs_not_sampling,
                                       shape=base_shape))

      all_finished = math_ops.reduce_all(finished)
      no_samples = math_ops.logical_not(math_ops.reduce_any(sample_ids))
      next_inputs = control_flow_ops.cond(
          math_ops.logical_or(all_finished, no_samples),
          lambda: base_next_inputs, maybe_sample)
      return (finished, next_inputs, state)
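# Illustrative sketch (assumption, not from this file): feeding the raw RNN
# output back as the next input with probability `sampling_probability`, for a
# decoder whose outputs have the same dimensionality as its inputs.
# `decoder_inputs` and `lengths` are hypothetical placeholders.
#
#   helper = ScheduledOutputTrainingHelper(
#       inputs=decoder_inputs,
#       sequence_length=lengths,
#       sampling_probability=0.1)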


class GreedyEmbeddingHelper(Helper):
  """A helper for use during inference.

  Uses the argmax of the output (treated as logits) and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token):
    """Initializer.

    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`. The returned tensor
        will be passed to the decoder input.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.

    Raises:
      ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
        scalar.
    """
    if callable(embedding):
      self._embedding_fn = embedding
    else:
      self._embedding_fn = (
          lambda ids: embedding_ops.embedding_lookup(embedding, ids))

    self._start_tokens = ops.convert_to_tensor(
        start_tokens, dtype=dtypes.int32, name="start_tokens")
    self._end_token = ops.convert_to_tensor(
        end_token, dtype=dtypes.int32, name="end_token")
    if self._start_tokens.get_shape().ndims != 1:
      raise ValueError("start_tokens must be a vector")
    self._batch_size = array_ops.size(start_tokens)
    if self._end_token.get_shape().ndims != 0:
      raise ValueError("end_token must be a scalar")
    self._start_inputs = self._embedding_fn(self._start_tokens)

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def sample_ids_shape(self):
    return tensor_shape.TensorShape([])

  @property
  def sample_ids_dtype(self):
    return dtypes.int32

  def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

  def sample(self, time, outputs, state, name=None):
    """sample for GreedyEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, use argmax to get the most probable id
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    sample_ids = math_ops.argmax(outputs, axis=-1, output_type=dtypes.int32)
    return sample_ids

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    """next_inputs_fn for GreedyEmbeddingHelper."""
    del time, outputs  # unused by next_inputs_fn
    finished = math_ops.equal(sample_ids, self._end_token)
    all_finished = math_ops.reduce_all(finished)
    next_inputs = control_flow_ops.cond(
        all_finished,
        # If we're finished, the next_inputs value doesn't matter
        lambda: self._start_inputs,
        lambda: self._embedding_fn(sample_ids))
    return (finished, next_inputs, state)
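# Illustrative sketch (assumption, not from this file): greedy decoding at
# inference time, bounded by `maximum_iterations`. `embedding_matrix`,
# `start_token_id`, `end_token_id`, `batch_size`, `cell`, and `initial_state`
# are hypothetical placeholders.
#
#   helper = GreedyEmbeddingHelper(
#       embedding=embedding_matrix,
#       start_tokens=array_ops.fill([batch_size], start_token_id),
#       end_token=end_token_id)
#   basic_decoder = tf.contrib.seq2seq.BasicDecoder(
#       cell=cell, helper=helper, initial_state=initial_state)
#   outputs, _, _ = tf.contrib.seq2seq.dynamic_decode(
#       basic_decoder, maximum_iterations=100)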


class SampleEmbeddingHelper(GreedyEmbeddingHelper):
  """A helper for use during inference.

  Uses sampling (from a distribution) instead of argmax and passes the
  result through an embedding layer to get the next input.
  """

  def __init__(self, embedding, start_tokens, end_token,
               softmax_temperature=None, seed=None):
    """Initializer.

    Args:
      embedding: A callable that takes a vector tensor of `ids` (argmax ids),
        or the `params` argument for `embedding_lookup`. The returned tensor
        will be passed to the decoder input.
      start_tokens: `int32` vector shaped `[batch_size]`, the start tokens.
      end_token: `int32` scalar, the token that marks end of decoding.
      softmax_temperature: (Optional) `float32` scalar, value to divide the
        logits by before computing the softmax. Larger values (above 1.0) result
        in more random samples, while smaller values push the sampling
        distribution towards the argmax. Must be strictly greater than 0.
        Defaults to 1.0.
      seed: (Optional) The sampling seed.

    Raises:
      ValueError: if `start_tokens` is not a 1D tensor or `end_token` is not a
        scalar.
    """
    super(SampleEmbeddingHelper, self).__init__(
        embedding, start_tokens, end_token)
    self._softmax_temperature = softmax_temperature
    self._seed = seed

  def sample(self, time, outputs, state, name=None):
    """sample for SampleEmbeddingHelper."""
    del time, state  # unused by sample_fn
    # Outputs are logits, we sample instead of argmax (greedy).
    if not isinstance(outputs, ops.Tensor):
      raise TypeError("Expected outputs to be a single Tensor, got: %s" %
                      type(outputs))
    if self._softmax_temperature is None:
      logits = outputs
    else:
      logits = outputs / self._softmax_temperature

    sample_id_sampler = categorical.Categorical(logits=logits)
    sample_ids = sample_id_sampler.sample(seed=self._seed)

    return sample_ids
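# Illustrative sketch (assumption, not from this file): same setup as the
# GreedyEmbeddingHelper example above, but sampling from the softmax with a
# temperature above 1.0 for more diverse outputs. Placeholder names as before.
#
#   helper = SampleEmbeddingHelper(
#       embedding=embedding_matrix,
#       start_tokens=array_ops.fill([batch_size], start_token_id),
#       end_token=end_token_id,
#       softmax_temperature=1.5)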


class InferenceHelper(Helper):
  """A helper to use during inference with a custom sampling function."""

  def __init__(self, sample_fn, sample_shape, sample_dtype,
               start_inputs, end_fn, next_inputs_fn=None):
    """Initializer.

    Args:
      sample_fn: A callable that takes `outputs` and emits tensor `sample_ids`.
      sample_shape: Either a list of integers, or a 1-D Tensor of type `int32`,
        the shape of each sample in the batch returned by `sample_fn`.
      sample_dtype: the dtype of the sample returned by `sample_fn`.
      start_inputs: The initial batch of inputs.
      end_fn: A callable that takes `sample_ids` and emits a `bool` vector
        shaped `[batch_size]` indicating whether each sample is an end token.
      next_inputs_fn: (Optional) A callable that takes `sample_ids` and returns
        the next batch of inputs. If not provided, `sample_ids` is used as the
        next batch of inputs.
    """
    self._sample_fn = sample_fn
    self._end_fn = end_fn
    self._sample_shape = tensor_shape.TensorShape(sample_shape)
    self._sample_dtype = sample_dtype
    self._next_inputs_fn = next_inputs_fn
    self._batch_size = array_ops.shape(start_inputs)[0]
    self._start_inputs = ops.convert_to_tensor(
        start_inputs, name="start_inputs")

  @property
  def batch_size(self):
    return self._batch_size

  @property
  def sample_ids_shape(self):
    return self._sample_shape

  @property
  def sample_ids_dtype(self):
    return self._sample_dtype

  def initialize(self, name=None):
    finished = array_ops.tile([False], [self._batch_size])
    return (finished, self._start_inputs)

  def sample(self, time, outputs, state, name=None):
    del time, state  # unused by sample
    return self._sample_fn(outputs)

  def next_inputs(self, time, outputs, state, sample_ids, name=None):
    del time, outputs  # unused by next_inputs
    if self._next_inputs_fn is None:
      next_inputs = sample_ids
    else:
      next_inputs = self._next_inputs_fn(sample_ids)
    finished = self._end_fn(sample_ids)
    return (finished, next_inputs, state)
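# Illustrative sketch (assumption, not from this file): an `InferenceHelper`
# for a decoder that emits a real-valued vector per step and feeds it straight
# back in, stopping once the first feature goes negative. `start_frame`
# (a [batch_size, dims] tensor) and `dims` are hypothetical placeholders.
#
#   helper = InferenceHelper(
#       sample_fn=lambda outputs: outputs,
#       sample_shape=[dims],
#       sample_dtype=dtypes.float32,
#       start_inputs=start_frame,
#       end_fn=lambda sample_ids: sample_ids[:, 0] < 0.0)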