# Copyright 2017 The TensorFlow Authors. All Rights Reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# ==============================================================================
"""Tests for `Dataset.shard()`."""
from __future__ import absolute_import
from __future__ import division
from __future__ import print_function

from tensorflow.python.data.ops import dataset_ops
from tensorflow.python.framework import errors
from tensorflow.python.platform import test


class ShardDatasetOpTest(test.TestCase):
  """Tests for `Dataset.shard(num_shards, index)`.

  The expected values below encode the contract under test: the sharded
  dataset yields the input elements whose position `p` satisfies
  `p % num_shards == index`, in their original order. The argument-validation
  tests check that a zero or negative `num_shards`, and a negative `index` or
  one greater than `num_shards`, raise `ValueError` when `shard()` is called.
  """

  def _assertShardOutputs(self, dataset, expected):
    """Asserts that `dataset` yields exactly `expected`, then is exhausted.

    Args:
      dataset: The `Dataset` under test.
      expected: Sequence of values (tuples for zipped datasets) that `dataset`
        should produce, in order. May be empty, in which case the very first
        `get_next` must raise `OutOfRangeError`.
    """
    # Build the `get_next` op once and run it repeatedly. Calling
    # `iterator.get_next()` per element would add a fresh op to the graph on
    # every call.
    get_next = dataset.make_one_shot_iterator().get_next()
    with self.test_session() as sess:
      for value in expected:
        self.assertEqual(value, sess.run(get_next))
      with self.assertRaises(errors.OutOfRangeError):
        sess.run(get_next)

  def testSimpleCase(self):
    self._assertShardOutputs(dataset_ops.Dataset.range(10).shard(5, 2), [2, 7])

  def testNestedData(self):
    """Sharding applies to whole (zipped) elements, not to their components."""
    dataset_a = dataset_ops.Dataset.range(10)
    dataset_b = dataset_ops.Dataset.range(10, 0, -1)  # 10, 9, ..., 1
    dataset = dataset_ops.Dataset.zip((dataset_a, dataset_b)).shard(5, 2)
    self._assertShardOutputs(dataset, [(2, 8), (7, 3)])

  def testOffsetZero(self):
    self._assertShardOutputs(dataset_ops.Dataset.range(10).shard(5, 0), [0, 5])

  def testOffsetGreaterNumShards(self):
    with self.assertRaises(ValueError):
      dataset_ops.Dataset.range(10).shard(5, 7)

  def testNegativeOffset(self):
    with self.assertRaises(ValueError):
      dataset_ops.Dataset.range(10).shard(5, -3)

  def testNegativeNumShards(self):
    with self.assertRaises(ValueError):
      dataset_ops.Dataset.range(10).shard(-3, 1)

  def testZeroNumShards(self):
    with self.assertRaises(ValueError):
      dataset_ops.Dataset.range(10).shard(0, 1)

  def testIteratorEndsBeforeFirstElem(self):
    """A shard whose index exceeds the input length yields nothing."""
    self._assertShardOutputs(dataset_ops.Dataset.range(1).shard(5, 2), [])

  def testLargerWorkerPool(self):
    """More shards than would divide the input evenly."""
    self._assertShardOutputs(dataset_ops.Dataset.range(10).shard(7, 5), [5])

  def testIndexEqualsNumShards(self):
    """Uses the last valid index (`num_shards - 1`); 10 divides evenly by 5."""
    self._assertShardOutputs(dataset_ops.Dataset.range(10).shard(5, 4), [4, 9])

  def testIndexEqualsNumShards2(self):
    """Uses the last valid index (`num_shards - 1`); 10 does not divide by 4."""
    self._assertShardOutputs(dataset_ops.Dataset.range(10).shard(4, 3), [3, 7])


if __name__ == "__main__":
  test.main()