1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for contrib.seq2seq.python.seq2seq.beam_search_decoder.""" 16 # pylint: disable=unused-import,g-bad-import-order 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 # pylint: enable=unused-import 21 22 import numpy as np 23 24 from tensorflow.contrib.seq2seq.python.ops import attention_wrapper 25 from tensorflow.contrib.seq2seq.python.ops import beam_search_decoder 26 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops 27 from tensorflow.contrib.seq2seq.python.ops import decoder 28 from tensorflow.python.framework import constant_op 29 from tensorflow.python.framework import dtypes 30 from tensorflow.python.framework import ops 31 from tensorflow.python.layers import core as layers_core 32 from tensorflow.python.ops import array_ops 33 from tensorflow.python.ops import nn_ops 34 from tensorflow.python.ops import rnn_cell 35 from tensorflow.python.ops import variables 36 from tensorflow.python.platform import test 37 38 # pylint: enable=g-import-not-at-top 39 40 41 class TestGatherTree(test.TestCase): 42 """Tests the gather_tree function.""" 43 44 def test_gather_tree(self): 45 # (max_time = 3, batch_size = 2, beam_width = 3) 46 47 # create (batch_size, max_time, beam_width) matrix and transpose it 48 predicted_ids = np.array( 49 [[[1, 2, 3], [4, 5, 6], [7, 8, 9]], [[2, 3, 4], [5, 6, 7], [8, 9, 10]]], 50 dtype=np.int32).transpose([1, 0, 2]) 51 parent_ids = np.array( 52 [[[0, 0, 0], [0, 1, 1], [2, 1, 2]], [[0, 0, 0], [1, 2, 0], [2, 1, 1]]], 53 dtype=np.int32).transpose([1, 0, 2]) 54 55 # sequence_lengths is shaped (batch_size = 3) 56 max_sequence_lengths = [3, 3] 57 58 expected_result = np.array([[[2, 2, 2], [6, 5, 6], [7, 8, 9]], 59 [[2, 4, 4], [7, 6, 6], 60 [8, 9, 10]]]).transpose([1, 0, 2]) 61 62 res = beam_search_ops.gather_tree( 63 predicted_ids, 64 parent_ids, 65 max_sequence_lengths=max_sequence_lengths, 66 end_token=11) 67 68 with self.test_session() as sess: 69 res_ = sess.run(res) 70 71 self.assertAllEqual(expected_result, res_) 72 73 74 class TestEosMasking(test.TestCase): 75 """Tests EOS masking used in beam search.""" 76 77 def test_eos_masking(self): 78 probs = constant_op.constant([ 79 [[-.2, -.2, -.2, -.2, -.2], [-.3, -.3, -.3, 3, 0], [5, 6, 0, 0, 0]], 80 [[-.2, -.2, -.2, -.2, 0], [-.3, -.3, -.1, 3, 0], [5, 6, 3, 0, 0]], 81 ]) 82 83 eos_token = 0 84 previously_finished = np.array([[0, 1, 0], [0, 1, 1]], dtype=bool) 85 masked = beam_search_decoder._mask_probs(probs, eos_token, 86 previously_finished) 87 88 with self.test_session() as sess: 89 probs = sess.run(probs) 90 masked = sess.run(masked) 91 92 self.assertAllEqual(probs[0][0], masked[0][0]) 93 self.assertAllEqual(probs[0][2], masked[0][2]) 94 self.assertAllEqual(probs[1][0], masked[1][0]) 95 96 self.assertEqual(masked[0][1][0], 0) 97 self.assertEqual(masked[1][1][0], 0) 98 self.assertEqual(masked[1][2][0], 0) 99 100 for i in range(1, 5): 101 self.assertAllClose(masked[0][1][i], np.finfo('float32').min) 102 self.assertAllClose(masked[1][1][i], np.finfo('float32').min) 103 self.assertAllClose(masked[1][2][i], np.finfo('float32').min) 104 105 106 class TestBeamStep(test.TestCase): 107 """Tests a single step of beam search.""" 108 109 def setUp(self): 110 super(TestBeamStep, self).setUp() 111 self.batch_size = 2 112 self.beam_width = 3 113 self.vocab_size = 5 114 self.end_token = 0 115 self.length_penalty_weight = 0.6 116 117 def test_step(self): 118 dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) 119 beam_state = beam_search_decoder.BeamSearchDecoderState( 120 cell_state=dummy_cell_state, 121 log_probs=nn_ops.log_softmax( 122 array_ops.ones([self.batch_size, self.beam_width])), 123 lengths=constant_op.constant( 124 2, shape=[self.batch_size, self.beam_width], dtype=dtypes.int64), 125 finished=array_ops.zeros( 126 [self.batch_size, self.beam_width], dtype=dtypes.bool)) 127 128 logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size], 129 0.0001) 130 logits_[0, 0, 2] = 1.9 131 logits_[0, 0, 3] = 2.1 132 logits_[0, 1, 3] = 3.1 133 logits_[0, 1, 4] = 0.9 134 logits_[1, 0, 1] = 0.5 135 logits_[1, 1, 2] = 2.7 136 logits_[1, 2, 2] = 10.0 137 logits_[1, 2, 3] = 0.2 138 logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32) 139 log_probs = nn_ops.log_softmax(logits) 140 141 outputs, next_beam_state = beam_search_decoder._beam_search_step( 142 time=2, 143 logits=logits, 144 next_cell_state=dummy_cell_state, 145 beam_state=beam_state, 146 batch_size=ops.convert_to_tensor(self.batch_size), 147 beam_width=self.beam_width, 148 end_token=self.end_token, 149 length_penalty_weight=self.length_penalty_weight) 150 151 with self.test_session() as sess: 152 outputs_, next_state_, state_, log_probs_ = sess.run( 153 [outputs, next_beam_state, beam_state, log_probs]) 154 155 self.assertAllEqual(outputs_.predicted_ids, [[3, 3, 2], [2, 2, 1]]) 156 self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [2, 1, 0]]) 157 self.assertAllEqual(next_state_.lengths, [[3, 3, 3], [3, 3, 3]]) 158 self.assertAllEqual(next_state_.finished, 159 [[False, False, False], [False, False, False]]) 160 161 expected_log_probs = [] 162 expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) 163 expected_log_probs.append(state_.log_probs[1][[2, 1, 0]]) # 0 --> 1 164 expected_log_probs[0][0] += log_probs_[0, 1, 3] 165 expected_log_probs[0][1] += log_probs_[0, 0, 3] 166 expected_log_probs[0][2] += log_probs_[0, 0, 2] 167 expected_log_probs[1][0] += log_probs_[1, 2, 2] 168 expected_log_probs[1][1] += log_probs_[1, 1, 2] 169 expected_log_probs[1][2] += log_probs_[1, 0, 1] 170 self.assertAllEqual(next_state_.log_probs, expected_log_probs) 171 172 def test_step_with_eos(self): 173 dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) 174 beam_state = beam_search_decoder.BeamSearchDecoderState( 175 cell_state=dummy_cell_state, 176 log_probs=nn_ops.log_softmax( 177 array_ops.ones([self.batch_size, self.beam_width])), 178 lengths=ops.convert_to_tensor( 179 [[2, 1, 2], [2, 2, 1]], dtype=dtypes.int64), 180 finished=ops.convert_to_tensor( 181 [[False, True, False], [False, False, True]], dtype=dtypes.bool)) 182 183 logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size], 184 0.0001) 185 logits_[0, 0, 2] = 1.9 186 logits_[0, 0, 3] = 2.1 187 logits_[0, 1, 3] = 3.1 188 logits_[0, 1, 4] = 0.9 189 logits_[1, 0, 1] = 0.5 190 logits_[1, 1, 2] = 5.7 # why does this not work when it's 2.7? 191 logits_[1, 2, 2] = 1.0 192 logits_[1, 2, 3] = 0.2 193 logits = ops.convert_to_tensor(logits_, dtype=dtypes.float32) 194 log_probs = nn_ops.log_softmax(logits) 195 196 outputs, next_beam_state = beam_search_decoder._beam_search_step( 197 time=2, 198 logits=logits, 199 next_cell_state=dummy_cell_state, 200 beam_state=beam_state, 201 batch_size=ops.convert_to_tensor(self.batch_size), 202 beam_width=self.beam_width, 203 end_token=self.end_token, 204 length_penalty_weight=self.length_penalty_weight) 205 206 with self.test_session() as sess: 207 outputs_, next_state_, state_, log_probs_ = sess.run( 208 [outputs, next_beam_state, beam_state, log_probs]) 209 210 self.assertAllEqual(outputs_.parent_ids, [[1, 0, 0], [1, 2, 0]]) 211 self.assertAllEqual(outputs_.predicted_ids, [[0, 3, 2], [2, 0, 1]]) 212 self.assertAllEqual(next_state_.lengths, [[1, 3, 3], [3, 1, 3]]) 213 self.assertAllEqual(next_state_.finished, 214 [[True, False, False], [False, True, False]]) 215 216 expected_log_probs = [] 217 expected_log_probs.append(state_.log_probs[0][[1, 0, 0]]) 218 expected_log_probs.append(state_.log_probs[1][[1, 2, 0]]) 219 expected_log_probs[0][1] += log_probs_[0, 0, 3] 220 expected_log_probs[0][2] += log_probs_[0, 0, 2] 221 expected_log_probs[1][0] += log_probs_[1, 1, 2] 222 expected_log_probs[1][2] += log_probs_[1, 0, 1] 223 self.assertAllEqual(next_state_.log_probs, expected_log_probs) 224 225 226 class TestLargeBeamStep(test.TestCase): 227 """Tests large beam step. 228 229 Tests a single step of beam search in such case that beam size is larger than 230 vocabulary size. 231 """ 232 233 def setUp(self): 234 super(TestLargeBeamStep, self).setUp() 235 self.batch_size = 2 236 self.beam_width = 8 237 self.vocab_size = 5 238 self.end_token = 0 239 self.length_penalty_weight = 0.6 240 241 def test_step(self): 242 243 def get_probs(): 244 """this simulates the initialize method in BeamSearchDecoder.""" 245 log_prob_mask = array_ops.one_hot( 246 array_ops.zeros([self.batch_size], dtype=dtypes.int32), 247 depth=self.beam_width, 248 on_value=True, 249 off_value=False, 250 dtype=dtypes.bool) 251 252 log_prob_zeros = array_ops.zeros( 253 [self.batch_size, self.beam_width], dtype=dtypes.float32) 254 log_prob_neg_inf = array_ops.ones( 255 [self.batch_size, self.beam_width], dtype=dtypes.float32) * -np.Inf 256 257 log_probs = array_ops.where(log_prob_mask, log_prob_zeros, 258 log_prob_neg_inf) 259 return log_probs 260 261 log_probs = get_probs() 262 dummy_cell_state = array_ops.zeros([self.batch_size, self.beam_width]) 263 264 # pylint: disable=invalid-name 265 _finished = array_ops.one_hot( 266 array_ops.zeros([self.batch_size], dtype=dtypes.int32), 267 depth=self.beam_width, 268 on_value=False, 269 off_value=True, 270 dtype=dtypes.bool) 271 _lengths = np.zeros([self.batch_size, self.beam_width], dtype=np.int64) 272 _lengths[:, 0] = 2 273 _lengths = constant_op.constant(_lengths, dtype=dtypes.int64) 274 275 beam_state = beam_search_decoder.BeamSearchDecoderState( 276 cell_state=dummy_cell_state, 277 log_probs=log_probs, 278 lengths=_lengths, 279 finished=_finished) 280 281 logits_ = np.full([self.batch_size, self.beam_width, self.vocab_size], 282 0.0001) 283 logits_[0, 0, 2] = 1.9 284 logits_[0, 0, 3] = 2.1 285 logits_[0, 1, 3] = 3.1 286 logits_[0, 1, 4] = 0.9 287 logits_[1, 0, 1] = 0.5 288 logits_[1, 1, 2] = 2.7 289 logits_[1, 2, 2] = 10.0 290 logits_[1, 2, 3] = 0.2 291 logits = constant_op.constant(logits_, dtype=dtypes.float32) 292 log_probs = nn_ops.log_softmax(logits) 293 294 outputs, next_beam_state = beam_search_decoder._beam_search_step( 295 time=2, 296 logits=logits, 297 next_cell_state=dummy_cell_state, 298 beam_state=beam_state, 299 batch_size=ops.convert_to_tensor(self.batch_size), 300 beam_width=self.beam_width, 301 end_token=self.end_token, 302 length_penalty_weight=self.length_penalty_weight) 303 304 with self.test_session() as sess: 305 outputs_, next_state_, _, _ = sess.run( 306 [outputs, next_beam_state, beam_state, log_probs]) 307 308 self.assertEqual(outputs_.predicted_ids[0, 0], 3) 309 self.assertEqual(outputs_.predicted_ids[0, 1], 2) 310 self.assertEqual(outputs_.predicted_ids[1, 0], 1) 311 neg_inf = -np.Inf 312 self.assertAllEqual( 313 next_state_.log_probs[:, -3:], 314 [[neg_inf, neg_inf, neg_inf], [neg_inf, neg_inf, neg_inf]]) 315 self.assertEqual((next_state_.log_probs[:, :-3] > neg_inf).all(), True) 316 self.assertEqual((next_state_.lengths[:, :-3] > 0).all(), True) 317 self.assertAllEqual(next_state_.lengths[:, -3:], [[0, 0, 0], [0, 0, 0]]) 318 319 320 class BeamSearchDecoderTest(test.TestCase): 321 322 def _testDynamicDecodeRNN(self, time_major, has_attention): 323 encoder_sequence_length = np.array([3, 2, 3, 1, 1]) 324 decoder_sequence_length = np.array([2, 0, 1, 2, 3]) 325 batch_size = 5 326 decoder_max_time = 4 327 input_depth = 7 328 cell_depth = 9 329 attention_depth = 6 330 vocab_size = 20 331 end_token = vocab_size - 1 332 start_token = 0 333 embedding_dim = 50 334 max_out = max(decoder_sequence_length) 335 output_layer = layers_core.Dense(vocab_size, use_bias=True, activation=None) 336 beam_width = 3 337 338 with self.test_session() as sess: 339 batch_size_tensor = constant_op.constant(batch_size) 340 embedding = np.random.randn(vocab_size, embedding_dim).astype(np.float32) 341 cell = rnn_cell.LSTMCell(cell_depth) 342 initial_state = cell.zero_state(batch_size, dtypes.float32) 343 if has_attention: 344 inputs = array_ops.placeholder_with_default( 345 np.random.randn(batch_size, decoder_max_time, input_depth).astype( 346 np.float32), 347 shape=(None, None, input_depth)) 348 tiled_inputs = beam_search_decoder.tile_batch( 349 inputs, multiplier=beam_width) 350 tiled_sequence_length = beam_search_decoder.tile_batch( 351 encoder_sequence_length, multiplier=beam_width) 352 attention_mechanism = attention_wrapper.BahdanauAttention( 353 num_units=attention_depth, 354 memory=tiled_inputs, 355 memory_sequence_length=tiled_sequence_length) 356 initial_state = beam_search_decoder.tile_batch( 357 initial_state, multiplier=beam_width) 358 cell = attention_wrapper.AttentionWrapper( 359 cell=cell, 360 attention_mechanism=attention_mechanism, 361 attention_layer_size=attention_depth, 362 alignment_history=False) 363 cell_state = cell.zero_state( 364 dtype=dtypes.float32, batch_size=batch_size_tensor * beam_width) 365 if has_attention: 366 cell_state = cell_state.clone(cell_state=initial_state) 367 bsd = beam_search_decoder.BeamSearchDecoder( 368 cell=cell, 369 embedding=embedding, 370 start_tokens=array_ops.fill([batch_size_tensor], start_token), 371 end_token=end_token, 372 initial_state=cell_state, 373 beam_width=beam_width, 374 output_layer=output_layer, 375 length_penalty_weight=0.0) 376 377 final_outputs, final_state, final_sequence_lengths = ( 378 decoder.dynamic_decode( 379 bsd, output_time_major=time_major, maximum_iterations=max_out)) 380 381 def _t(shape): 382 if time_major: 383 return (shape[1], shape[0]) + shape[2:] 384 return shape 385 386 self.assertTrue( 387 isinstance(final_outputs, 388 beam_search_decoder.FinalBeamSearchDecoderOutput)) 389 self.assertTrue( 390 isinstance(final_state, beam_search_decoder.BeamSearchDecoderState)) 391 392 beam_search_decoder_output = final_outputs.beam_search_decoder_output 393 self.assertEqual( 394 _t((batch_size, None, beam_width)), 395 tuple(beam_search_decoder_output.scores.get_shape().as_list())) 396 self.assertEqual( 397 _t((batch_size, None, beam_width)), 398 tuple(final_outputs.predicted_ids.get_shape().as_list())) 399 400 sess.run(variables.global_variables_initializer()) 401 sess_results = sess.run({ 402 'final_outputs': final_outputs, 403 'final_state': final_state, 404 'final_sequence_lengths': final_sequence_lengths 405 }) 406 407 max_sequence_length = np.max(sess_results['final_sequence_lengths']) 408 409 # A smoke test 410 self.assertEqual( 411 _t((batch_size, max_sequence_length, beam_width)), 412 sess_results['final_outputs'].beam_search_decoder_output.scores.shape) 413 self.assertEqual( 414 _t((batch_size, max_sequence_length, beam_width)), sess_results[ 415 'final_outputs'].beam_search_decoder_output.predicted_ids.shape) 416 417 def testDynamicDecodeRNNBatchMajorNoAttention(self): 418 self._testDynamicDecodeRNN(time_major=False, has_attention=False) 419 420 def testDynamicDecodeRNNBatchMajorYesAttention(self): 421 self._testDynamicDecodeRNN(time_major=False, has_attention=True) 422 423 424 if __name__ == '__main__': 425 test.main() 426