# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
     15 """Tests for contrib.seq2seq.python.seq2seq.beam_search_decoder."""
# pylint: disable=unused-import,g-bad-import-order
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function
# pylint: enable=unused-import

import numpy as np

from tensorflow.contrib.seq2seq.python.ops import attention_wrapper
from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder
from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
from tensorflow.contrib.seq2seq.python.ops import decoder
from tensorflow.python.framework import constant_op
from tensorflow.python.framework import dtypes
from tensorflow.python.framework import ops
from tensorflow.python.layers import core as layers_core
from tensorflow.python.ops import array_ops
from tensorflow.python.ops import nn_ops
from tensorflow.python.ops import rnn_cell
from tensorflow.python.ops import variables
from tensorflow.python.platform import test

# pylint: enable=g-import-not-at-top


class TestGatherTree(test.TestCase):
  """Tests the gather_tree function."""

  def test_gather_tree(self):
    # (max_time = 3, batch_size = 2, beam_width = 3)

    # Create (batch_size, max_time, beam_width) arrays and transpose them
    # into time-major (max_time, batch_size, beam_width) order.
    predicted_ids = np.array(
        [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]],
        dtype=np.int32).transpose([1, 0, 2])
    parent_ids = np.array(
        [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]],
        dtype=np.int32).transpose([1, 0, 2])

    # max_sequence_lengths is shaped (batch_size = 2)
    max_sequence_lengths = [3, 3]

    expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]],
                                [[2, 4, 4], [7, 6, 6],
                                 [8, 9, 10]]]).transpose([1, 0, 2])
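    # gather_tree reconstructs each beam by walking backwards from the last
    # step through parent_ids. For example, beam 0 of batch 0 emits 7 at the
    # final step, follows parent 2 to emit 6, then parent 1 to emit 2, giving
    # the time-major column [2, 6, 7] in expected_result above.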

    res = beam_search_ops.gather_tree(
        predicted_ids,
        parent_ids,
        max_sequence_lengths=max_sequence_lengths,
        end_token=11)

    with self.test_session() as sess:
      res_ = sess.run(res)

    self.assertAllEqual(expected_result, res_)


class TestEosMasking(test.TestCase):
  """Tests EOS masking used in beam search."""

  def test_eos_masking(self):
    probs = constant_op.constant([
        [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]],
        [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]],
    ])

    eos_token = 0
    previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool)
    masked = beam_search_decoder._mask_probs(probs, eos_token,
                                             previously_finished)
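
    # For beams marked as previously finished, _mask_probs should leave only
    # the EOS entry (set to 0) and push every other entry down to the smallest
    # float32 value; entries for unfinished beams pass through unchanged. The
    # assertions below check exactly that.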

    with self.test_session() as sess:
      probs = sess.run(probs)
      masked = sess.run(masked)

      self.assertAllEqual(probs[0][0], masked[0][0])
      self.assertAllEqual(probs[0][2], masked[0][2])
      self.assertAllEqual(probs[1][0], masked[1][0])

      self.assertEqual(masked[0][1][0], 0)
      self.assertEqual(masked[1][1][0], 0)
      self.assertEqual(masked[1][2][0], 0)

      for i in range(1, 5):
        self.assertAllClose(masked[0][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][1][i], np.finfo('float32').min)
        self.assertAllClose(masked[1][2][i], np.finfo('float32').min)


class TestBeamStep(test.TestCase):
  """Tests a single step of beam search."""

  def setUp(self):
    super(TestBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 3
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6

  def test_step(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=constant_op.constant(
            2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64),
        finished=array_ops.zeros(
            [self.batch_size, self.beam_width], dtype=dtypes.bool))

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
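    # These logits are crafted so that the three best continuations for
    # batch 0 come from beams 1, 0, 0 (tokens 3, 3, 2) and for batch 1 from
    # beams 2, 1, 0 (tokens 2, 2, 1), which is what the assertions below
    # expect.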
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]])
    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]])
    self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[False, False, False], [False, False, False]])

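    # Each new beam's accumulated log-prob should equal its parent beam's
    # previous log-prob plus the log-softmax score of the token it selected.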
    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[2, 1, 0]])  # 0 --> 1
    expected_log_probs[0][0] += log_probs_[0, 1, 3]
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 2, 2]
    expected_log_probs[1][1] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)

  def test_step_with_eos(self):
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])
    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=nn_ops.log_softmax(
            array_ops.ones([self.batch_size, self.beam_width])),
        lengths=ops.convert_to_tensor(
            [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64),
        finished=ops.convert_to_tensor(
            [[False, True, False], [False, False, True]], dtype=dtypes.bool))
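    # Beam 1 of batch 0 and beam 2 of batch 1 enter this step already
    # finished, so the step should carry them forward emitting only the end
    # token, without growing their lengths or accumulating new log-probs.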

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 5.7  # NOTE: the expectations below do not hold at 2.7.
    logits_[1, 2, 2] = 1.0
    logits_[1, 2, 3] = 0.2
    logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
      outputs_, next_state_, state_, log_probs_ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

    self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]])
    self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]])
    self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]])
    self.assertAllEqual(next_state_.finished,
                        [[True, False, False], [False, True, False]])

    expected_log_probs = []
    expected_log_probs.append(state_.log_probs[0][[1, 0, 0]])
    expected_log_probs.append(state_.log_probs[1][[1, 2, 0]])
    expected_log_probs[0][1] += log_probs_[0, 0, 3]
    expected_log_probs[0][2] += log_probs_[0, 0, 2]
    expected_log_probs[1][0] += log_probs_[1, 1, 2]
    expected_log_probs[1][2] += log_probs_[1, 0, 1]
    self.assertAllEqual(next_state_.log_probs, expected_log_probs)


class TestLargeBeamStep(test.TestCase):
  """Tests a single step of beam search with a large beam.

  Covers the case where the beam width is larger than the vocabulary size.
  """

  def setUp(self):
    super(TestLargeBeamStep, self).setUp()
    self.batch_size = 2
    self.beam_width = 8
    self.vocab_size = 5
    self.end_token = 0
    self.length_penalty_weight = 0.6

  def test_step(self):

    def get_probs():
      """Simulates the initialize method in BeamSearchDecoder."""
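      # Only beam 0 of each batch starts with log-prob 0; every other beam
      # starts at -inf, matching how BeamSearchDecoder.initialize seeds the
      # search from a single live beam per batch.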
      log_prob_mask = array_ops.one_hot(
          array_ops.zeros([self.batch_size], dtype=dtypes.int32),
          depth=self.beam_width,
          on_value=True,
          off_value=False,
          dtype=dtypes.bool)

      log_prob_zeros = array_ops.zeros(
          [self.batch_size, self.beam_width], dtype=dtypes.float32)
      log_prob_neg_inf = array_ops.ones(
          [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf

      log_probs = array_ops.where(log_prob_mask, log_prob_zeros,
                                  log_prob_neg_inf)
      return log_probs

    log_probs = get_probs()
    dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width])

    # pylint: disable=invalid-name
    _finished = array_ops.one_hot(
        array_ops.zeros([self.batch_size], dtype=dtypes.int32),
        depth=self.beam_width,
        on_value=False,
        off_value=True,
        dtype=dtypes.bool)
    _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64)
    _lengths[:, 0] = 2
    _lengths = constant_op.constant(_lengths, dtype=dtypes.int64)

    beam_state = beam_search_decoder.BeamSearchDecoderState(
        cell_state=dummy_cell_state,
        log_probs=log_probs,
        lengths=_lengths,
        finished=_finished)

    logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size],
                      0.0001)
    logits_[0, 0, 2] = 1.9
    logits_[0, 0, 3] = 2.1
    logits_[0, 1, 3] = 3.1
    logits_[0, 1, 4] = 0.9
    logits_[1, 0, 1] = 0.5
    logits_[1, 1, 2] = 2.7
    logits_[1, 2, 2] = 10.0
    logits_[1, 2, 3] = 0.2
    logits = constant_op.constant(logits_, dtype=dtypes.float32)
    log_probs = nn_ops.log_softmax(logits)

    outputs, next_beam_state = beam_search_decoder._beam_search_step(
        time=2,
        logits=logits,
        next_cell_state=dummy_cell_state,
        beam_state=beam_state,
        batch_size=ops.convert_to_tensor(self.batch_size),
        beam_width=self.beam_width,
        end_token=self.end_token,
        length_penalty_weight=self.length_penalty_weight)

    with self.test_session() as sess:
      outputs_, next_state_, _, _ = sess.run(
          [outputs, next_beam_state, beam_state, log_probs])

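    # With only 5 vocabulary entries and a single live beam at the start, at
    # most 5 of the 8 output beams can hold real candidates; the last 3 should
    # stay at -inf with length 0.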
    self.assertEqual(outputs_.predicted_ids[0, 0], 3)
    self.assertEqual(outputs_.predicted_ids[0, 1], 2)
    self.assertEqual(outputs_.predicted_ids[1, 0], 1)
    neg_inf = -np.Inf
    self.assertAllEqual(
        next_state_.log_probs[:, -3:],
        [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]])
    self.assertTrue((next_state_.log_probs[:, :-3] > neg_inf).all())
    self.assertTrue((next_state_.lengths[:, :-3] > 0).all())
    self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]])


class BeamSearchDecoderTest(test.TestCase):

  def _testDynamicDecodeRNN(self, time_major, has_attention):
    encoder_sequence_length = np.array([3, 2, 3, 1, 1])
    decoder_sequence_length = np.array([2, 0, 1, 2, 3])
    batch_size = 5
    decoder_max_time = 4
    input_depth = 7
    cell_depth = 9
    attention_depth = 6
    vocab_size = 20
    end_token = vocab_size - 1
    start_token = 0
    embedding_dim = 50
    max_out = max(decoder_sequence_length)
    output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None)
    beam_width = 3

    with self.test_session() as sess:
      batch_size_tensor = constant_op.constant(batch_size)
      embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32)
      cell = rnn_cell.LSTMCell(cell_depth)
      initial_state = cell.zero_state(batch_size, dtypes.float32)
      if has_attention:
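        # The attention memory, its sequence lengths, and the encoder's final
        # state are all tiled beam_width times, since the beam search decoder
        # treats each (batch, beam) pair as an independent pseudo-batch entry.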
        inputs = array_ops.placeholder_with_default(
            np.random.randn(batch_size, decoder_max_time, input_depth).astype(
                np.float32),
            shape=(None, None, input_depth))
        tiled_inputs = beam_search_decoder.tile_batch(
            inputs, multiplier=beam_width)
        tiled_sequence_length = beam_search_decoder.tile_batch(
            encoder_sequence_length, multiplier=beam_width)
        attention_mechanism = attention_wrapper.BahdanauAttention(
            num_units=attention_depth,
            memory=tiled_inputs,
            memory_sequence_length=tiled_sequence_length)
        initial_state = beam_search_decoder.tile_batch(
            initial_state, multiplier=beam_width)
        cell = attention_wrapper.AttentionWrapper(
            cell=cell,
            attention_mechanism=attention_mechanism,
            attention_layer_size=attention_depth,
            alignment_history=False)
      cell_state = cell.zero_state(
          dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width)
      if has_attention:
        cell_state = cell_state.clone(cell_state=initial_state)
      bsd = beam_search_decoder.BeamSearchDecoder(
          cell=cell,
          embedding=embedding,
          start_tokens=array_ops.fill([batch_size_tensor], start_token),
          end_token=end_token,
          initial_state=cell_state,
          beam_width=beam_width,
          output_layer=output_layer,
          length_penalty_weight=0.0)

      final_outputs, final_state, final_sequence_lengths = (
          decoder.dynamic_decode(
              bsd, output_time_major=time_major, maximum_iterations=max_out))
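
      # In batch-major mode the decoder outputs are shaped
      # [batch_size, time, beam_width]; time-major output swaps the first two
      # axes, which is what _t below accounts for in the shape checks.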

      def _t(shape):
        if time_major:
          return (shape[1], shape[0]) + shape[2:]
        return shape

      self.assertIsInstance(final_outputs,
                            beam_search_decoder.FinalBeamSearchDecoderOutput)
      self.assertIsInstance(final_state,
                            beam_search_decoder.BeamSearchDecoderState)

      beam_search_decoder_output = final_outputs.beam_search_decoder_output
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(beam_search_decoder_output.scores.get_shape().as_list()))
      self.assertEqual(
          _t((batch_size, None, beam_width)),
          tuple(final_outputs.predicted_ids.get_shape().as_list()))

      sess.run(variables.global_variables_initializer())
      sess_results = sess.run({
          'final_outputs': final_outputs,
          'final_state': final_state,
          'final_sequence_lengths': final_sequence_lengths
      })

      max_sequence_length = np.max(sess_results['final_sequence_lengths'])

      # A smoke test
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)),
          sess_results['final_outputs'].beam_search_decoder_output.scores.shape)
      self.assertEqual(
          _t((batch_size, max_sequence_length, beam_width)), sess_results[
              'final_outputs'].beam_search_decoder_output.predicted_ids.shape)

  def testDynamicDecodeRNNBatchMajorNoAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=False)

  def testDynamicDecodeRNNBatchMajorYesAttention(self):
    self._testDynamicDecodeRNN(time_major=False, has_attention=True)


if __name__ == '__main__':
  test.main()