Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2019 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for `tf.data.experimental.take_while()`."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from absl.testing import parameterized
     21 import numpy as np
     22 
     23 from tensorflow.python.data.experimental.ops import take_while_ops
     24 from tensorflow.python.data.kernel_tests import test_base
     25 from tensorflow.python.data.ops import dataset_ops
     26 from tensorflow.python.framework import constant_op
     27 from tensorflow.python.framework import errors
     28 from tensorflow.python.framework import test_util
     29 from tensorflow.python.ops import array_ops
     30 from tensorflow.python.ops import math_ops
     31 from tensorflow.python.platform import test
     32 
     33 
     34 @test_util.run_all_in_graph_and_eager_modes
     35 class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase):
     36 
     37   @parameterized.parameters((14, 2), (15, 2), (100, 3))
     38   def testTakeWhileDataset(self, num_elements, window_size):
     39 
     40     def _predicate_func(elem):
     41       return array_ops.shape(elem)[0] > (window_size - 1)
     42 
     43     take_while = take_while_ops.take_while(_predicate_func)
     44 
     45     dataset = dataset_ops.Dataset.range(num_elements).batch(window_size)
     46     dataset = dataset.apply(take_while).flat_map(
     47         dataset_ops.Dataset.from_tensor_slices)
     48 
     49     expected_num_elements = int(num_elements / window_size) * window_size
     50     self.assertDatasetProduces(dataset, np.arange(expected_num_elements))
     51 
     52   @parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False),
     53                             (100, 101, True), (0, 1, True))
     54   def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds):
     55     dataset = dataset_ops.Dataset.range(num_elements).apply(
     56         take_while_ops.take_while(lambda x: x < upper_bound))
     57 
     58     if out_of_bounds:
     59       with self.assertRaises(errors.OutOfRangeError):
     60         self.assertDatasetProduces(dataset, np.arange(upper_bound))
     61 
     62     else:
     63       self.assertDatasetProduces(dataset, np.arange(upper_bound))
     64 
     65   def testTakeWhileDatasetString(self):
     66 
     67     def not_equal(string):
     68       return lambda x: math_ops.not_equal(x, constant_op.constant(string))
     69 
     70     string = ["this", "is", "the", "test", "for", "strings"]
     71     dataset = dataset_ops.Dataset.from_tensor_slices(string).apply(
     72         take_while_ops.take_while(not_equal("test")))
     73 
     74     next_element = self.getNext(dataset)
     75     self.assertEqual(b"this", self.evaluate(next_element()))
     76     self.assertEqual(b"is", self.evaluate(next_element()))
     77     self.assertEqual(b"the", self.evaluate(next_element()))
     78 
     79     with self.assertRaises(errors.OutOfRangeError):
     80       self.assertEqual(b"test", self.evaluate(next_element()))
     81 
     82   @parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7))
     83   def testTakewhileDatasetShortCircuit(self, size, index):
     84 
     85     def _predicate_func(data_elem):
     86       return data_elem
     87 
     88     boolean_array = [True] * size
     89     boolean_array[index] = False
     90     dataset = dataset_ops.Dataset.from_tensor_slices(boolean_array).apply(
     91         take_while_ops.take_while(_predicate_func))
     92 
     93     next_element = self.getNext(dataset)
     94 
     95     for _ in range(index):
     96       self.assertTrue(self.evaluate(next_element()))
     97 
     98     with self.assertRaises(errors.OutOfRangeError):
     99       self.evaluate(next_element())
    100 
    101 
    102 if __name__ == "__main__":
    103   test.main()
    104