# Copyright 2018 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `tf.data.experimental.cardinality()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from absl.testing import parameterized

from tensorflow.python.data.experimental.ops import cardinality
from tensorflow.python.data.kernel_tests import test_base
from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import test_util
from tensorflow.python.platform import test


@test_util.run_all_in_graph_and_eager_modes
class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
  """Tests for `tf.data.experimental.cardinality()`.

  Each case pairs a zero-argument dataset factory with the cardinality that
  `cardinality.cardinality()` is expected to report for it. The expected
  value is either a concrete element count or one of the sentinels
  `cardinality.UNKNOWN` / `cardinality.INFINITE`.
  """

  # Datasets are built lazily through lambdas so that each one is constructed
  # inside the test body, under whichever execution mode (graph or eager) the
  # class decorator is currently running, rather than at import time.
  @parameterized.named_parameters(
      # pylint: disable=g-long-lambda
      ("Batch1",
       lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
      ("Batch2",
       lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=False), 3),
      ("Batch3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).batch(2),
       cardinality.UNKNOWN),
      ("Batch4", lambda: dataset_ops.Dataset.range(5).repeat().batch(2),
       cardinality.INFINITE),
      ("Cache1", lambda: dataset_ops.Dataset.range(5).cache(), 5),
      ("Cache2", lambda: dataset_ops.Dataset.range(5).cache("foo"), 5),
      ("Concatenate1", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5)), 10),
      ("Concatenate2",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5)), cardinality.UNKNOWN),
      ("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
       concatenate(dataset_ops.Dataset.range(5)), cardinality.INFINITE),
      ("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.UNKNOWN),
      ("Concatenate5",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.UNKNOWN),
      ("Concatenate6", lambda: dataset_ops.Dataset.range(5).repeat().
       concatenate(dataset_ops.Dataset.range(5).filter(lambda _: True)),
       cardinality.INFINITE),
      ("Concatenate7", lambda: dataset_ops.Dataset.range(5).concatenate(
          dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      ("Concatenate8",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
           dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      ("Concatenate9",
       lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
           dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
      ("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
          lambda _: dataset_ops.Dataset.from_tensors(0)), cardinality.UNKNOWN),
      ("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
       cardinality.UNKNOWN),
      ("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
      ("FromTensors2", lambda: dataset_ops.Dataset.from_tensors((0, 1)), 1),
      ("FromTensorSlices1",
       lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0]), 3),
      ("FromTensorSlices2",
       lambda: dataset_ops.Dataset.from_tensor_slices(([0, 0, 0], [1, 1, 1])),
       3),
      ("Interleave1", lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
       cardinality.UNKNOWN),
      ("Interleave2", lambda: dataset_ops.Dataset.range(5).interleave(
          lambda _: dataset_ops.Dataset.from_tensors(0),
          cycle_length=1,
          num_parallel_calls=1), cardinality.UNKNOWN),
      ("Map1", lambda: dataset_ops.Dataset.range(5).map(lambda x: x), 5),
      ("Map2", lambda: dataset_ops.Dataset.range(5).map(
          lambda x: x, num_parallel_calls=1), 5),
      ("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
          2, [], drop_remainder=True), 2),
      ("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
          2, [], drop_remainder=False), 3),
      ("PaddedBatch3", lambda: dataset_ops.Dataset.range(5).filter(
          lambda _: True).padded_batch(2, []), cardinality.UNKNOWN),
      ("PaddedBatch4",
       lambda: dataset_ops.Dataset.range(5).repeat().padded_batch(2, []),
       cardinality.INFINITE),
      ("Prefetch", lambda: dataset_ops.Dataset.range(5).prefetch(buffer_size=1),
       5),
      ("Range1", lambda: dataset_ops.Dataset.range(0), 0),
      ("Range2", lambda: dataset_ops.Dataset.range(5), 5),
      ("Range3", lambda: dataset_ops.Dataset.range(5, 10), 5),
      ("Range4", lambda: dataset_ops.Dataset.range(10, 5), 0),
      ("Range5", lambda: dataset_ops.Dataset.range(5, 10, 2), 3),
      ("Range6", lambda: dataset_ops.Dataset.range(10, 5, -2), 3),
      ("Repeat1", lambda: dataset_ops.Dataset.range(0).repeat(0), 0),
      ("Repeat2", lambda: dataset_ops.Dataset.range(1).repeat(0), 0),
      ("Repeat3", lambda: dataset_ops.Dataset.range(0).repeat(5), 0),
      ("Repeat4", lambda: dataset_ops.Dataset.range(1).repeat(5), 5),
      ("Repeat5", lambda: dataset_ops.Dataset.range(0).repeat(), 0),
      ("Repeat6", lambda: dataset_ops.Dataset.range(1).repeat(),
       cardinality.INFINITE),
      ("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
       5),
      ("Shard1", lambda: dataset_ops.Dataset.range(5).shard(2, 0), 3),
      ("Shard2", lambda: dataset_ops.Dataset.range(5).shard(8, 7), 0),
      ("Shard3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).shard(2, 0),
       cardinality.UNKNOWN),
      ("Shard4", lambda: dataset_ops.Dataset.range(5).repeat().shard(2, 0),
       cardinality.INFINITE),
      ("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
      ("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
      ("Skip3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).skip(2),
       cardinality.UNKNOWN),
      ("Skip4", lambda: dataset_ops.Dataset.range(5).repeat().skip(2),
       cardinality.INFINITE),
      ("Take1", lambda: dataset_ops.Dataset.range(5).take(2), 2),
      ("Take2", lambda: dataset_ops.Dataset.range(5).take(8), 5),
      ("Take3",
       lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).take(2),
       cardinality.UNKNOWN),
      ("Take4", lambda: dataset_ops.Dataset.range(5).repeat().take(2), 2),
      ("Window1", lambda: dataset_ops.Dataset.range(5).window(
          size=2, shift=2, drop_remainder=True), 2),
      ("Window2", lambda: dataset_ops.Dataset.range(5).window(
          size=2, shift=2, drop_remainder=False), 3),
      ("Zip1", lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5)),
       5),
      ("Zip2", lambda: dataset_ops.Dataset.zip(
          (dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
      ("Zip3", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
          5), dataset_ops.Dataset.range(3).repeat())), 5),
      ("Zip4", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
          5).repeat(), dataset_ops.Dataset.range(3).repeat())),
       cardinality.INFINITE),
      ("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
          5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
       cardinality.UNKNOWN),
      # pylint: enable=g-long-lambda
  )
  def testNumElements(self, dataset_fn, expected_result):
    """Checks the cardinality reported for the dataset built by `dataset_fn`.

    Args:
      dataset_fn: Zero-argument callable returning the dataset under test.
      expected_result: The expected cardinality: a concrete element count, or
        one of the sentinels `cardinality.UNKNOWN` / `cardinality.INFINITE`.
    """
    # `self.evaluate` runs the cardinality tensor in whichever mode the
    # `run_all_in_graph_and_eager_modes` decorator is currently exercising.
    # It replaces the graph-only `cached_session()` / `sess.run()` pattern.
    self.assertEqual(
        self.evaluate(cardinality.cardinality(dataset_fn())), expected_result)


if __name__ == "__main__":
  test.main()