Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for contrib.seq2seq.python.seq2seq.beam_search_ops."""
     16 # pylint: disable=unused-import,g-bad-import-order
     17 from __future__ import absolute_import
     18 from __future__ import division
     19 from __future__ import print_function
     20 # pylint: enable=unused-import
     21 
     22 import itertools
     23 
     24 import numpy as np
     25 
     26 from tensorflow.contrib.seq2seq.python.ops import beam_search_ops
     27 from tensorflow.python.framework import ops
     28 from tensorflow.python.platform import test
     29 
     30 
     31 def _transpose_batch_time(x):
     32   return np.transpose(x, [1, 0, 2]).astype(np.int32)
     33 
     34 
     35 class GatherTreeTest(test.TestCase):
     36 
     37   def testGatherTreeOne(self):
     38     # (max_time = 4, batch_size = 1, beams = 3)
     39     end_token = 10
     40     step_ids = _transpose_batch_time(
     41         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     42     parent_ids = _transpose_batch_time(
     43         [[[0, 0, 0], [0, 1, 1], [2, 1, 2], [-1, -1, -1]]])
     44     max_sequence_lengths = [3]
     45     expected_result = _transpose_batch_time([[[2, 2, 2], [6, 5, 6], [7, 8, 9],
     46                                               [10, 10, 10]]])
     47     beams = beam_search_ops.gather_tree(
     48         step_ids=step_ids,
     49         parent_ids=parent_ids,
     50         max_sequence_lengths=max_sequence_lengths,
     51         end_token=end_token)
     52     with self.test_session(use_gpu=True):
     53       self.assertAllEqual(expected_result, beams.eval())
     54 
     55   def testBadParentValuesOnCPU(self):
     56     # (batch_size = 1, max_time = 4, beams = 3)
     57     # bad parent in beam 1 time 1
     58     end_token = 10
     59     step_ids = _transpose_batch_time(
     60         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     61     parent_ids = _transpose_batch_time(
     62         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
     63     max_sequence_lengths = [3]
     64     with ops.device("/cpu:0"):
     65       beams = beam_search_ops.gather_tree(
     66           step_ids=step_ids,
     67           parent_ids=parent_ids,
     68           max_sequence_lengths=max_sequence_lengths,
     69           end_token=end_token)
     70     with self.test_session():
     71       with self.assertRaisesOpError(
     72           r"parent id -1 at \(batch, time, beam\) == \(0, 0, 1\)"):
     73         _ = beams.eval()
     74 
     75   def testBadParentValuesOnGPU(self):
     76     # Only want to run this test on CUDA devices, as gather_tree is not
     77     # registered for SYCL devices.
     78     if not test.is_gpu_available(cuda_only=True):
     79       return
     80     # (max_time = 4, batch_size = 1, beams = 3)
     81     # bad parent in beam 1 time 1; appears as a negative index at time 0
     82     end_token = 10
     83     step_ids = _transpose_batch_time(
     84         [[[1, 2, 3], [4, 5, 6], [7, 8, 9], [-1, -1, -1]]])
     85     parent_ids = _transpose_batch_time(
     86         [[[0, 0, 0], [0, -1, 1], [2, 1, 2], [-1, -1, -1]]])
     87     max_sequence_lengths = [3]
     88     expected_result = _transpose_batch_time([[[2, -1, 2], [6, 5, 6], [7, 8, 9],
     89                                               [10, 10, 10]]])
     90     with ops.device("/device:GPU:0"):
     91       beams = beam_search_ops.gather_tree(
     92           step_ids=step_ids,
     93           parent_ids=parent_ids,
     94           max_sequence_lengths=max_sequence_lengths,
     95           end_token=end_token)
     96     with self.test_session(use_gpu=True):
     97       self.assertAllEqual(expected_result, beams.eval())
     98 
     99   def testGatherTreeBatch(self):
    100     batch_size = 10
    101     beam_width = 15
    102     max_time = 8
    103     max_sequence_lengths = [0, 1, 2, 4, 7, 8, 9, 10, 11, 0]
    104     end_token = 5
    105 
    106     with self.test_session(use_gpu=True):
    107       step_ids = np.random.randint(
    108           0, high=end_token + 1, size=(max_time, batch_size, beam_width))
    109       parent_ids = np.random.randint(
    110           0, high=beam_width - 1, size=(max_time, batch_size, beam_width))
    111 
    112       beams = beam_search_ops.gather_tree(
    113           step_ids=step_ids.astype(np.int32),
    114           parent_ids=parent_ids.astype(np.int32),
    115           max_sequence_lengths=max_sequence_lengths,
    116           end_token=end_token)
    117 
    118       self.assertEqual((max_time, batch_size, beam_width), beams.shape)
    119       beams_value = beams.eval()
    120       for b in range(batch_size):
    121         # Past max_sequence_lengths[b], we emit all end tokens.
    122         b_value = beams_value[max_sequence_lengths[b]:, b, :]
    123         self.assertAllClose(b_value, end_token * np.ones_like(b_value))
    124       for batch, beam in itertools.product(
    125           range(batch_size), range(beam_width)):
    126         v = np.squeeze(beams_value[:, batch, beam])
    127         if end_token in v:
    128           found_bad = np.where(v == -1)[0]
    129           self.assertEqual(0, len(found_bad))
    130           found = np.where(v == end_token)[0]
    131           found = found[0]  # First occurrence of end_token.
    132           # If an end_token is found, everything before it should be a
    133           # valid id and everything after it should be -1.
    134           if found > 0:
    135             self.assertAllEqual(
    136                 v[:found - 1] >= 0, np.ones_like(v[:found - 1], dtype=bool))
    137           self.assertAllClose(v[found + 1:],
    138                               end_token * np.ones_like(v[found + 1:]))
    139 
    140 
    141 if __name__ == "__main__":
    142   test.main()
    143