Home | History | Annotate | Download | only in kernel_tests
      1 # Copyright 2017 The TensorFlow Authors. All Rights Reserved.
      2 #
      3 # Licensed under the Apache License, Version 2.0 (the "License");
      4 # you may not use this file except in compliance with the License.
      5 # You may obtain a copy of the License at
      6 #
      7 #     http://www.apache.org/licenses/LICENSE-2.0
      8 #
      9 # Unless required by applicable law or agreed to in writing, software
     10 # distributed under the License is distributed on an "AS IS" BASIS,
     11 # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
     12 # See the License for the specific language governing permissions and
     13 # limitations under the License.
     14 # ==============================================================================
     15 """Tests for the `Dataset.shard()` input pipeline transformation."""
     16 from __future__ import absolute_import
     17 from __future__ import division
     18 from __future__ import print_function
     19 
     20 from tensorflow.python.data.ops import dataset_ops
     21 from tensorflow.python.framework import errors
     22 from tensorflow.python.platform import test
     23 
     24 
     25 class ShardDatasetOpTest(test.TestCase):
     26 
     27   def testSimpleCase(self):
     28     dataset = dataset_ops.Dataset.range(10).shard(5, 2)
     29     iterator = dataset.make_one_shot_iterator()
     30 
     31     with self.test_session() as sess:
     32       self.assertEqual(2, sess.run(iterator.get_next()))
     33       self.assertEqual(7, sess.run(iterator.get_next()))
     34       with self.assertRaises(errors.OutOfRangeError):
     35         sess.run(iterator.get_next())
     36 
     37   def testNestedData(self):
     38     dataset_a = dataset_ops.Dataset.range(10)
     39     dataset_b = dataset_ops.Dataset.range(10, 0, -1)
     40     dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
     41     iterator = dataset.make_one_shot_iterator()
     42 
     43     with self.test_session() as sess:
     44       self.assertEqual((2, 8), sess.run(iterator.get_next()))
     45       self.assertEqual((7, 3), sess.run(iterator.get_next()))
     46       with self.assertRaises(errors.OutOfRangeError):
     47         sess.run(iterator.get_next())
     48 
     49   def testOffsetZero(self):
     50     dataset = dataset_ops.Dataset.range(10).shard(5, 0)
     51     iterator = dataset.make_one_shot_iterator()
     52 
     53     with self.test_session() as sess:
     54       self.assertEqual(0, sess.run(iterator.get_next()))
     55       self.assertEqual(5, sess.run(iterator.get_next()))
     56       with self.assertRaises(errors.OutOfRangeError):
     57         sess.run(iterator.get_next())
     58 
     59   def testOffsetGreaterNumShards(self):
     60     with self.assertRaises(ValueError):
     61       dataset_ops.Dataset.range(10).shard(5, 7)
     62 
     63   def testNegativeOffset(self):
     64     with self.assertRaises(ValueError):
     65       dataset_ops.Dataset.range(10).shard(5, -3)
     66 
     67   def testNegativeNumShards(self):
     68     with self.assertRaises(ValueError):
     69       dataset_ops.Dataset.range(10).shard(-3, 1)
     70 
     71   def testZeroNumShards(self):
     72     with self.assertRaises(ValueError):
     73       dataset_ops.Dataset.range(10).shard(0, 1)
     74 
     75   def testIteratorEndsBeforeFirstElem(self):
     76     dataset = dataset_ops.Dataset.range(1).shard(5, 2)
     77     iterator = dataset.make_one_shot_iterator()
     78 
     79     with self.test_session() as sess:
     80       with self.assertRaises(errors.OutOfRangeError):
     81         sess.run(iterator.get_next())
     82 
     83   def testLargerWorkerPool(self):
     84     dataset = dataset_ops.Dataset.range(10).shard(7, 5)
     85     iterator = dataset.make_one_shot_iterator()
     86     with self.test_session() as sess:
     87       self.assertEqual(5, sess.run(iterator.get_next()))
     88       with self.assertRaises(errors.OutOfRangeError):
     89         sess.run(iterator.get_next())
     90 
     91   def testIndexEqualsNumShards(self):
     92     dataset = dataset_ops.Dataset.range(10).shard(5, 4)
     93     iterator = dataset.make_one_shot_iterator()
     94     with self.test_session() as sess:
     95       self.assertEqual(4, sess.run(iterator.get_next()))
     96       self.assertEqual(9, sess.run(iterator.get_next()))
     97       with self.assertRaises(errors.OutOfRangeError):
     98         sess.run(iterator.get_next())
     99 
    100   def testIndexEqualsNumShards2(self):
    101     dataset = dataset_ops.Dataset.range(10).shard(4, 3)
    102     iterator = dataset.make_one_shot_iterator()
    103     with self.test_session() as sess:
    104       self.assertEqual(3, sess.run(iterator.get_next()))
    105       self.assertEqual(7, sess.run(iterator.get_next()))
    106       with self.assertRaises(errors.OutOfRangeError):
    107         sess.run(iterator.get_next())
    108 
    109 
    110 if __name__ == "__main__":
    111   test.main()
    112