1 # Copyright 2019 The TensorFlow Authors. All Rights Reserved. 2 # 3 # Licensed under the Apache License, Version 2.0 (the "License"); 4 # you may not use this file except in compliance with the License. 5 # You may obtain a copy of the License at 6 # 7 # http://www.apache.org/licenses/LICENSE-2.0 8 # 9 # Unless required by applicable law or agreed to in writing, software 10 # distributed under the License is distributed on an "AS IS" BASIS, 11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 12 # See the License for the specific language governing permissions and 13 # limitations under the License. 14 # ============================================================================== 15 """Tests for `tf.data.experimental.take_while()`.""" 16 from __future__ import absolute_import 17 from __future__ import division 18 from __future__ import print_function 19 20 from absl.testing import parameterized 21 import numpy as np 22 23 from tensorflow.python.data.experimental.ops import take_while_ops 24 from tensorflow.python.data.kernel_tests import test_base 25 from tensorflow.python.data.ops import dataset_ops 26 from tensorflow.python.framework import constant_op 27 from tensorflow.python.framework import errors 28 from tensorflow.python.framework import test_util 29 from tensorflow.python.ops import array_ops 30 from tensorflow.python.ops import math_ops 31 from tensorflow.python.platform import test 32 33 34 @test_util.run_all_in_graph_and_eager_modes 35 class TakeWhileTest(test_base.DatasetTestBase, parameterized.TestCase): 36 37 @parameterized.parameters((14, 2), (15, 2), (100, 3)) 38 def testTakeWhileDataset(self, num_elements, window_size): 39 40 def _predicate_func(elem): 41 return array_ops.shape(elem)[0] > (window_size - 1) 42 43 take_while = take_while_ops.take_while(_predicate_func) 44 45 dataset = dataset_ops.Dataset.range(num_elements).batch(window_size) 46 dataset = dataset.apply(take_while).flat_map( 47 dataset_ops.Dataset.from_tensor_slices) 48 49 expected_num_elements = int(num_elements / window_size) * window_size 50 self.assertDatasetProduces(dataset, np.arange(expected_num_elements)) 51 52 @parameterized.parameters((10, 2, False), (16, 7, False), (100, 99, False), 53 (100, 101, True), (0, 1, True)) 54 def testTakeWhileDatasetRange(self, num_elements, upper_bound, out_of_bounds): 55 dataset = dataset_ops.Dataset.range(num_elements).apply( 56 take_while_ops.take_while(lambda x: x < upper_bound)) 57 58 if out_of_bounds: 59 with self.assertRaises(errors.OutOfRangeError): 60 self.assertDatasetProduces(dataset, np.arange(upper_bound)) 61 62 else: 63 self.assertDatasetProduces(dataset, np.arange(upper_bound)) 64 65 def testTakeWhileDatasetString(self): 66 67 def not_equal(string): 68 return lambda x: math_ops.not_equal(x, constant_op.constant(string)) 69 70 string = ["this", "is", "the", "test", "for", "strings"] 71 dataset = dataset_ops.Dataset.from_tensor_slices(string).apply( 72 take_while_ops.take_while(not_equal("test"))) 73 74 next_element = self.getNext(dataset) 75 self.assertEqual(b"this", self.evaluate(next_element())) 76 self.assertEqual(b"is", self.evaluate(next_element())) 77 self.assertEqual(b"the", self.evaluate(next_element())) 78 79 with self.assertRaises(errors.OutOfRangeError): 80 self.assertEqual(b"test", self.evaluate(next_element())) 81 82 @parameterized.parameters((5, 3), (10, 0), (100, 5), (8, 7)) 83 def testTakewhileDatasetShortCircuit(self, size, index): 84 85 def _predicate_func(data_elem): 86 return data_elem 87 88 boolean_array = [True] * size 89 boolean_array[index] = False 90 dataset = dataset_ops.Dataset.from_tensor_slices(boolean_array).apply( 91 take_while_ops.take_while(_predicate_func)) 92 93 next_element = self.getNext(dataset) 94 95 for _ in range(index): 96 self.assertTrue(self.evaluate(next_element())) 97 98 with self.assertRaises(errors.OutOfRangeError): 99 self.evaluate(next_element()) 100 101 102 if __name__ == "__main__": 103 test.main() 104