1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for contrib.seq2seq.python.seq2seq.beam_search_ops.""" 16 # pylint: disable=unused-import,g-bad-import-order 17 from __future__ import absolute_import 18 from __future__ import division 19 from __future__ import print_function 20 # pylint: enable=unused-import 21 22 import itertools 23 24 import numpy as np 25 26 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops 27 from tensorflow.python.framework import ops 28 from tensorflow.python.platform import test 29 30 31 def _transpose_batch_time(x): 32 return np.transpose(x, [1, 0, 2]).astype(np.int32) 33 34 35 class GatherTreeTest(test.TestCase): 36 37 def testGatherTreeOne(self): 38 # (max_time = 4, batch_size = 1, beams = 3) 39 end_token = 10 40 step_ids = _transpose_batch_time( 41 [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) 42 parent_ids = _transpose_batch_time( 43 [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]]) 44 max_sequence_lengths = [3] 45 expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9], 46 [10, 10, 10]]]) 47 beams = beam_search_ops.gather_tree( 48 step_ids=step_ids, 49 parent_ids=parent_ids, 50 max_sequence_lengths=max_sequence_lengths, 51 end_token=end_token) 52 with self.test_session(use_gpu=True): 53 self.assertAllEqual(expected_result, beams.eval()) 54 55 def testBadParentValuesOnCPU(self): 56 # (batch_size = 1, max_time = 4, beams = 3) 57 # bad parent in beam 1 time 1 58 end_token = 10 59 step_ids = _transpose_batch_time( 60 [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) 61 parent_ids = _transpose_batch_time( 62 [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) 63 max_sequence_lengths = [3] 64 with ops.device("/cpu:0"): 65 beams = beam_search_ops.gather_tree( 66 step_ids=step_ids, 67 parent_ids=parent_ids, 68 max_sequence_lengths=max_sequence_lengths, 69 end_token=end_token) 70 with self.test_session(): 71 with self.assertRaisesOpError( 72 r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"): 73 _ = beams.eval() 74 75 def testBadParentValuesOnGPU(self): 76 # Only want to run this test on CUDA devices, as gather_tree is not 77 # registered for SYCL devices. 78 if not test.is_gpu_available(cuda_only=True): 79 return 80 # (max_time = 4, batch_size = 1, beams = 3) 81 # bad parent in beam 1 time 1; appears as a negative index at time 0 82 end_token = 10 83 step_ids = _transpose_batch_time( 84 [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]]) 85 parent_ids = _transpose_batch_time( 86 [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]]) 87 max_sequence_lengths = [3] 88 expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9], 89 [10, 10, 10]]]) 90 with ops.device("/device:GPU:0"): 91 beams = beam_search_ops.gather_tree( 92 step_ids=step_ids, 93 parent_ids=parent_ids, 94 max_sequence_lengths=max_sequence_lengths, 95 end_token=end_token) 96 with self.test_session(use_gpu=True): 97 self.assertAllEqual(expected_result, beams.eval()) 98 99 def testGatherTreeBatch(self): 100 batch_size = 10 101 beam_width = 15 102 max_time = 8 103 max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0] 104 end_token = 5 105 106 with self.test_session(use_gpu=True): 107 step_ids = np.random.randint( 108 0, high=end_token + 1, size=(max_time, batch_size, beam_width)) 109 parent_ids = np.random.randint( 110 0, high=beam_width - 1, size=(max_time, batch_size, beam_width)) 111 112 beams = beam_search_ops.gather_tree( 113 step_ids=step_ids.astype(np.int32), 114 parent_ids=parent_ids.astype(np.int32), 115 max_sequence_lengths=max_sequence_lengths, 116 end_token=end_token) 117 118 self.assertEqual((max_time, batch_size, beam_width), beams.shape) 119 beams_value = beams.eval() 120 for b in range(batch_size): 121 # Past max_sequence_lengths[b], we emit all end tokens. 122 b_value = beams_value[max_sequence_lengths[b]:, b, :] 123 self.assertAllClose(b_value, end_token * np.ones_like(b_value)) 124 for batch, beam in itertools.product( 125 range(batch_size), range(beam_width)): 126 v = np.squeeze(beams_value[:, batch, beam]) 127 if end_token in v: 128 found_bad = np.where(v == -1)[0] 129 self.assertEqual(0, len(found_bad)) 130 found = np.where(v == end_token)[0] 131 found = found[0] # First occurrence of end_token. 132 # If an end_token is found, everything before it should be a 133 # valid id and everything after it should be -1. 134 if found > 0: 135 self.assertAllEqual( 136 v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool)) 137 self.assertAllClose(v[found + 1:], 138 end_token * np.ones_like(v[found + 1:])) 139 140 141 if __name__ == "__main__": 142 test.main() 143