# (Code-search navigation residue; kept as a comment so the file parses.)
      1 # Copyright 2018 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for `tf.data.experimental.cardinality()`."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from absl.testing import parameterized
     21 
     22 from tensorflow.python.data.experimental.ops import cardinality
     23 from tensorflow.python.data.kernel_tests import test_base
     24 from tensorflow.python.data.ops import dataset_ops
     25 from tensorflow.python.framework import test_util
     26 from tensorflow.python.platform import test
     27 
     28 
     29 @test_util.run_all_in_graph_and_eager_modes
     30 class NumElementsTest(test_base.DatasetTestBase, parameterized.TestCase):
     31   """Tests for `tf.data.experimental.cardinality()`."""
     32 
     33   @parameterized.named_parameters(
     34       # pylint: disable=g-long-lambda
     35       ("Batch1",
     36        lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=True), 2),
     37       ("Batch2",
     38        lambda: dataset_ops.Dataset.range(5).batch(2, drop_remainder=False), 3),
     39       ("Batch3",
     40        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).batch(2),
     41        cardinality.UNKNOWN),
     42       ("Batch4", lambda: dataset_ops.Dataset.range(5).repeat().batch(2),
     43        cardinality.INFINITE),
     44       ("Cache1", lambda: dataset_ops.Dataset.range(5).cache(), 5),
     45       ("Cache2", lambda: dataset_ops.Dataset.range(5).cache("foo"), 5),
     46       ("Concatenate1", lambda: dataset_ops.Dataset.range(5).concatenate(
     47           dataset_ops.Dataset.range(5)), 10),
     48       ("Concatenate2",
     49        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
     50            dataset_ops.Dataset.range(5)), cardinality.UNKNOWN),
     51       ("Concatenate3", lambda: dataset_ops.Dataset.range(5).repeat().
     52        concatenate(dataset_ops.Dataset.range(5)), cardinality.INFINITE),
     53       ("Concatenate4", lambda: dataset_ops.Dataset.range(5).concatenate(
     54           dataset_ops.Dataset.range(5).filter(lambda _: True)),
     55        cardinality.UNKNOWN),
     56       ("Concatenate5",
     57        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
     58            dataset_ops.Dataset.range(5).filter(lambda _: True)),
     59        cardinality.UNKNOWN),
     60       ("Concatenate6", lambda: dataset_ops.Dataset.range(5).repeat().
     61        concatenate(dataset_ops.Dataset.range(5).filter(lambda _: True)),
     62        cardinality.INFINITE),
     63       ("Concatenate7", lambda: dataset_ops.Dataset.range(5).concatenate(
     64           dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
     65       ("Concatenate8",
     66        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).concatenate(
     67            dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
     68       ("Concatenate9",
     69        lambda: dataset_ops.Dataset.range(5).repeat().concatenate(
     70            dataset_ops.Dataset.range(5).repeat()), cardinality.INFINITE),
     71       ("FlatMap", lambda: dataset_ops.Dataset.range(5).flat_map(
     72           lambda _: dataset_ops.Dataset.from_tensors(0)), cardinality.UNKNOWN),
     73       ("Filter", lambda: dataset_ops.Dataset.range(5).filter(lambda _: True),
     74        cardinality.UNKNOWN),
     75       ("FromTensors1", lambda: dataset_ops.Dataset.from_tensors(0), 1),
     76       ("FromTensors2", lambda: dataset_ops.Dataset.from_tensors((0, 1)), 1),
     77       ("FromTensorSlices1",
     78        lambda: dataset_ops.Dataset.from_tensor_slices([0, 0, 0]), 3),
     79       ("FromTensorSlices2",
     80        lambda: dataset_ops.Dataset.from_tensor_slices(([0, 0, 0], [1, 1, 1])),
     81        3),
     82       ("Interleave1", lambda: dataset_ops.Dataset.range(5).interleave(
     83           lambda _: dataset_ops.Dataset.from_tensors(0), cycle_length=1),
     84        cardinality.UNKNOWN),
     85       ("Interleave2", lambda: dataset_ops.Dataset.range(5).interleave(
     86           lambda _: dataset_ops.Dataset.from_tensors(0),
     87           cycle_length=1,
     88           num_parallel_calls=1), cardinality.UNKNOWN),
     89       ("Map1", lambda: dataset_ops.Dataset.range(5).map(lambda x: x), 5),
     90       ("Map2", lambda: dataset_ops.Dataset.range(5).map(
     91           lambda x: x, num_parallel_calls=1), 5),
     92       ("PaddedBatch1", lambda: dataset_ops.Dataset.range(5).padded_batch(
     93           2, [], drop_remainder=True), 2),
     94       ("PaddedBatch2", lambda: dataset_ops.Dataset.range(5).padded_batch(
     95           2, [], drop_remainder=False), 3),
     96       ("PaddedBatch3", lambda: dataset_ops.Dataset.range(5).filter(
     97           lambda _: True).padded_batch(2, []), cardinality.UNKNOWN),
     98       ("PaddedBatch4",
     99        lambda: dataset_ops.Dataset.range(5).repeat().padded_batch(2, []),
    100        cardinality.INFINITE),
    101       ("Prefetch", lambda: dataset_ops.Dataset.range(5).prefetch(buffer_size=1),
    102        5),
    103       ("Range1", lambda: dataset_ops.Dataset.range(0), 0),
    104       ("Range2", lambda: dataset_ops.Dataset.range(5), 5),
    105       ("Range3", lambda: dataset_ops.Dataset.range(5, 10), 5),
    106       ("Range4", lambda: dataset_ops.Dataset.range(10, 5), 0),
    107       ("Range5", lambda: dataset_ops.Dataset.range(5, 10, 2), 3),
    108       ("Range6", lambda: dataset_ops.Dataset.range(10, 5, -2), 3),
    109       ("Repeat1", lambda: dataset_ops.Dataset.range(0).repeat(0), 0),
    110       ("Repeat2", lambda: dataset_ops.Dataset.range(1).repeat(0), 0),
    111       ("Repeat3", lambda: dataset_ops.Dataset.range(0).repeat(5), 0),
    112       ("Repeat4", lambda: dataset_ops.Dataset.range(1).repeat(5), 5),
    113       ("Repeat5", lambda: dataset_ops.Dataset.range(0).repeat(), 0),
    114       ("Repeat6", lambda: dataset_ops.Dataset.range(1).repeat(),
    115        cardinality.INFINITE),
    116       ("Shuffle", lambda: dataset_ops.Dataset.range(5).shuffle(buffer_size=1),
    117        5),
    118       ("Shard1", lambda: dataset_ops.Dataset.range(5).shard(2, 0), 3),
    119       ("Shard2", lambda: dataset_ops.Dataset.range(5).shard(8, 7), 0),
    120       ("Shard3",
    121        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).shard(2, 0),
    122        cardinality.UNKNOWN),
    123       ("Shard4", lambda: dataset_ops.Dataset.range(5).repeat().shard(2, 0),
    124        cardinality.INFINITE),
    125       ("Skip1", lambda: dataset_ops.Dataset.range(5).skip(2), 3),
    126       ("Skip2", lambda: dataset_ops.Dataset.range(5).skip(8), 0),
    127       ("Skip3",
    128        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).skip(2),
    129        cardinality.UNKNOWN),
    130       ("Skip4", lambda: dataset_ops.Dataset.range(5).repeat().skip(2),
    131        cardinality.INFINITE),
    132       ("Take1", lambda: dataset_ops.Dataset.range(5).take(2), 2),
    133       ("Take2", lambda: dataset_ops.Dataset.range(5).take(8), 5),
    134       ("Take3",
    135        lambda: dataset_ops.Dataset.range(5).filter(lambda _: True).take(2),
    136        cardinality.UNKNOWN),
    137       ("Take4", lambda: dataset_ops.Dataset.range(5).repeat().take(2), 2),
    138       ("Window1", lambda: dataset_ops.Dataset.range(5).window(
    139           size=2, shift=2, drop_remainder=True), 2),
    140       ("Window2", lambda: dataset_ops.Dataset.range(5).window(
    141           size=2, shift=2, drop_remainder=False), 3),
    142       ("Zip1", lambda: dataset_ops.Dataset.zip(dataset_ops.Dataset.range(5)),
    143        5),
    144       ("Zip2", lambda: dataset_ops.Dataset.zip(
    145           (dataset_ops.Dataset.range(5), dataset_ops.Dataset.range(3))), 3),
    146       ("Zip3", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
    147           5), dataset_ops.Dataset.range(3).repeat())), 5),
    148       ("Zip4", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
    149           5).repeat(), dataset_ops.Dataset.range(3).repeat())),
    150        cardinality.INFINITE),
    151       ("Zip5", lambda: dataset_ops.Dataset.zip((dataset_ops.Dataset.range(
    152           5), dataset_ops.Dataset.range(3).filter(lambda _: True))),
    153        cardinality.UNKNOWN),
    154       # pylint: enable=g-long-lambda
    155   )
    156   def testNumElements(self, dataset_fn, expected_result):
    157     with self.cached_session() as sess:
    158       self.assertEqual(
    159           sess.run(cardinality.cardinality(dataset_fn())), expected_result)
    160 
    161 
if __name__ == "__main__":
  # Run all test cases in this module via the TensorFlow test runner.
  test.main()
    164